mmq.cuh 73 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
66116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006
  1. #pragma once
  2. #include "common.cuh"
  3. #include "vecdotq.cuh"
  4. #include "mma.cuh"
  5. #include <climits>
  6. #include <cstdint>
  7. #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
  8. typedef void (*load_tiles_mmq_t)(
  9. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  10. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
  11. typedef void (*vec_dot_mmq_t)(
  12. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  13. const int * __restrict__ y, float * __restrict__ sum, const int & k0);
  14. typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
  15. struct block_q8_1_mmq {
  16. half2 ds[4];
  17. int8_t qs[4*QK8_1];
  18. };
  19. static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
  20. static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
  21. struct tile_x_sizes {
  22. int ql;
  23. int dm;
  24. int qh;
  25. int sc;
  26. };
  27. // get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
  28. static constexpr __device__ int get_mmq_x_max_device() {
  29. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  30. return 64;
  31. #else
  32. #if __CUDA_ARCH__ >= CC_VOLTA
  33. #ifdef CUDA_USE_TENSOR_CORES
  34. return MMQ_MAX_BATCH_SIZE;
  35. #else
  36. return 128;
  37. #endif // CUDA_USE_TENSOR_CORES
  38. #else
  39. return 64;
  40. #endif // __CUDA_ARCH__ >= CC_VOLTA
  41. #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  42. }
  43. // get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
  44. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  45. static constexpr __device__ int get_mmq_y_device(int mmq_x) {
  46. return mmq_x >= 32 ? 128 : 64;
  47. }
  48. #else
  49. #if __CUDA_ARCH__ >= CC_VOLTA
  50. static constexpr __device__ int get_mmq_y_device(int mmq_x) {
  51. return mmq_x >= 32 ? 128 : 64;
  52. }
  53. #else
  54. static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
  55. return 64;
  56. }
  57. #endif // __CUDA_ARCH__ >= CC_VOLTA
  58. #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  59. #define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0, 0}
  60. #define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0, 0}
  61. #define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0, 0}
  62. #define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0, 0}
  63. #define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0, 0}
  64. #define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0, mmq_y*WARP_SIZE/4 + mmq_y/4}
  65. #define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4}
  66. #define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
  67. #define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
  68. #define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
  69. #define GET_TILE_X_SIZES_BODY \
  70. return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
  71. type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \
  72. type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \
  73. type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \
  74. type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \
  75. type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \
  76. type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \
  77. type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
  78. type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
  79. type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
  80. tile_x_sizes{0, 0, 0, 0}
  81. static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
  82. GET_TILE_X_SIZES_BODY;
  83. }
  84. template <int mmq_y>
  85. static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
  86. GET_TILE_X_SIZES_BODY;
  87. }
  88. // ------------------------------------------------------------
  89. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
  90. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  91. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  92. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  93. const int kbx = threadIdx.x / QI4_0;
  94. const int kqsx = threadIdx.x % QI4_0;
  95. float * x_dmf = (float *) x_dm;
  96. #pragma unroll
  97. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  98. int i = i0 + threadIdx.y;
  99. if (need_check) {
  100. i = min(i, i_max);
  101. }
  102. const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
  103. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
  104. }
  105. const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
  106. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  107. #pragma unroll
  108. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
  109. int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
  110. if (need_check) {
  111. i = min(i, i_max);
  112. }
  113. const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
  114. x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
  115. }
  116. }
  117. template <int mmq_x, int mmq_y, int nwarps>
  118. static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
  119. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  120. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  121. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  122. const float * x_df = (const float *) x_dm;
  123. const int * y_qs = (const int *) y + 4;
  124. const half2 * y_ds = (const half2 *) y;
  125. #pragma unroll
  126. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  127. const int j = j0 + threadIdx.y;
  128. #pragma unroll
  129. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  130. const int i = i0 + threadIdx.x;
  131. const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
  132. int u[2*VDR_Q4_0_Q8_1_MMQ];
  133. #pragma unroll
  134. for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
  135. u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
  136. u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
  137. }
  138. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
  139. (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
  140. y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
  141. }
  142. }
  143. }
  144. template <int mmq_x, int mmq_y, int nwarps>
  145. static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
  146. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  147. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  148. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  149. typedef mma_int_A_I16K8 mma_A;
  150. typedef mma_int_B_J8K8 mma_B;
  151. typedef mma_int_C_I16J8 mma_C;
  152. const float * x_df = (const float *) x_dm;
  153. const int * y_qs = (const int *) y + 4;
  154. const half2 * y_ds = (const half2 *) y;
  155. mma_A A;
  156. float dA[mma_C::ne/2];
  157. const int i0 = threadIdx.y*mma_A::I;
  158. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  159. #pragma unroll
  160. for (int l = 0; l < mma_A::ne; ++l) {
  161. const int i = i0 + mma_A::get_i(l);
  162. const int k = k0 + mma_A::get_k(l) % QI4_0;
  163. const int shift = 4*(mma_A::get_k(l) / QI4_0);
  164. A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
  165. }
  166. #pragma unroll
  167. for (int l = 0; l < mma_C::ne/2; ++l) {
  168. const int i = i0 + mma_C::get_i(2*l);
  169. dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
  170. }
  171. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  172. mma_C C;
  173. mma_B B;
  174. half2 dsB[mma_C::ne/2];
  175. #pragma unroll
  176. for (int l = 0; l < mma_B::ne; ++l) {
  177. const int j = j0 + mma_B::get_j(l);
  178. const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
  179. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  180. }
  181. #pragma unroll
  182. for (int l = 0; l < mma_C::ne/2; ++l) {
  183. const int j = j0 + mma_C::get_j(l);
  184. dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
  185. }
  186. C.mma_K8(A, B);
  187. #pragma unroll
  188. for (int l = 0; l < mma_C::ne; ++l) {
  189. sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
  190. }
  191. }
  192. }
  193. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
  194. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  195. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  196. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  197. const int kbx = threadIdx.x / QI4_1;
  198. const int kqsx = threadIdx.x % QI4_1;
  199. #pragma unroll
  200. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  201. int i = i0 + threadIdx.y;
  202. if (need_check) {
  203. i = min(i, i_max);
  204. }
  205. const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
  206. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  207. }
  208. const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
  209. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  210. #pragma unroll
  211. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
  212. int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
  213. if (need_check) {
  214. i = min(i, i_max);
  215. }
  216. const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
  217. x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
  218. }
  219. }
  220. template <int mmq_x, int mmq_y, int nwarps>
  221. static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
  222. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  223. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  224. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  225. const int * y_qs = (const int *) y + 4;
  226. const half2 * y_ds = (const half2 *) y;
  227. #pragma unroll
  228. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  229. const int j = j0 + threadIdx.y;
  230. #pragma unroll
  231. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  232. const int i = i0 + threadIdx.x;
  233. const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
  234. int u[2*VDR_Q4_1_Q8_1_MMQ];
  235. #pragma unroll
  236. for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
  237. u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
  238. u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
  239. }
  240. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
  241. (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
  242. y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
  243. }
  244. }
  245. }
  246. template <int mmq_x, int mmq_y, int nwarps>
  247. static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
  248. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  249. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  250. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  251. typedef mma_int_A_I16K8 mma_A;
  252. typedef mma_int_B_J8K8 mma_B;
  253. typedef mma_int_C_I16J8 mma_C;
  254. const int * y_qs = (const int *) y + 4;
  255. const half2 * y_ds = (const half2 *) y;
  256. mma_A A;
  257. half2 dmA[mma_C::ne/2];
  258. const int i0 = threadIdx.y*mma_A::I;
  259. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  260. #pragma unroll
  261. for (int l = 0; l < mma_A::ne; ++l) {
  262. const int i = i0 + mma_A::get_i(l);
  263. const int k = k0 + mma_A::get_k(l) % QI4_0;
  264. const int shift = 4*(mma_A::get_k(l) / QI4_0);
  265. A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
  266. }
  267. #pragma unroll
  268. for (int l = 0; l < mma_C::ne/2; ++l) {
  269. const int i = i0 + mma_C::get_i(2*l);
  270. dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
  271. }
  272. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  273. mma_C C;
  274. mma_B B;
  275. half2 dsB[mma_C::ne/2];
  276. #pragma unroll
  277. for (int l = 0; l < mma_B::ne; ++l) {
  278. const int j = j0 + mma_B::get_j(l);
  279. const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
  280. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  281. }
  282. #pragma unroll
  283. for (int l = 0; l < mma_C::ne/2; ++l) {
  284. const int j = j0 + mma_C::get_j(l);
  285. dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
  286. }
  287. C.mma_K8(A, B);
  288. #pragma unroll
  289. for (int l = 0; l < mma_C::ne; ++l) {
  290. const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
  291. sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
  292. }
  293. }
  294. }
  295. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
  296. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  297. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  298. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  299. const int kbx = threadIdx.x / QI5_0;
  300. const int kqsx = threadIdx.x % QI5_0;
  301. #pragma unroll
  302. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  303. int i = i0 + threadIdx.y;
  304. if (need_check) {
  305. i = min(i, i_max);
  306. }
  307. const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
  308. const int ql = get_int_from_uint8(bxi->qs, kqsx);
  309. const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
  310. int qs0 = (ql >> 0) & 0x0F0F0F0F;
  311. qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
  312. qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
  313. qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
  314. qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
  315. qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
  316. x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
  317. int qs1 = (ql >> 4) & 0x0F0F0F0F;
  318. qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
  319. qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
  320. qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
  321. qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
  322. qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
  323. x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
  324. }
  325. const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
  326. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  327. float * x_dmf = (float *) x_dm;
  328. #pragma unroll
  329. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
  330. int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
  331. if (need_check) {
  332. i = min(i, i_max);
  333. }
  334. const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
  335. x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
  336. }
  337. }
  338. template <int mmq_x, int mmq_y, int nwarps>
  339. static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
  340. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  341. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  342. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  343. const float * x_dmf = (const float *) x_dm;
  344. const int * y_qs = (const int *) y + 4;
  345. const float * y_df = (const float *) y;
  346. #pragma unroll
  347. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  348. const int j = j0 + threadIdx.y;
  349. #pragma unroll
  350. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  351. const int i = i0 + threadIdx.x;
  352. const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
  353. const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
  354. int u[2*VDR_Q5_0_Q8_1_MMQ];
  355. #pragma unroll
  356. for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
  357. u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
  358. u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
  359. }
  360. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
  361. (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
  362. }
  363. }
  364. }
  365. template <int mmq_x, int mmq_y, int nwarps>
  366. static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
  367. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  368. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  369. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  370. typedef mma_int_A_I16K8 mma_A;
  371. typedef mma_int_B_J8K8 mma_B;
  372. typedef mma_int_C_I16J8 mma_C;
  373. const float * x_df = (const float *) x_dm;
  374. const int * y_qs = (const int *) y + 4;
  375. const float * y_df = (const float *) y;
  376. mma_A A;
  377. float dA[mma_C::ne/2];
  378. const int i0 = threadIdx.y*mma_A::I;
  379. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  380. #pragma unroll
  381. for (int l = 0; l < mma_A::ne; ++l) {
  382. const int i = i0 + mma_A::get_i(l);
  383. const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
  384. A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
  385. }
  386. #pragma unroll
  387. for (int l = 0; l < mma_C::ne/2; ++l) {
  388. const int i = i0 + mma_C::get_i(2*l);
  389. dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
  390. }
  391. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  392. mma_C C;
  393. mma_B B;
  394. float dB[mma_C::ne/2];
  395. #pragma unroll
  396. for (int l = 0; l < mma_B::ne; ++l) {
  397. const int j = j0 + mma_B::get_j(l);
  398. const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
  399. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  400. }
  401. #pragma unroll
  402. for (int l = 0; l < mma_C::ne/2; ++l) {
  403. const int j = j0 + mma_C::get_j(l);
  404. dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
  405. }
  406. C.mma_K8(A, B);
  407. #pragma unroll
  408. for (int l = 0; l < mma_C::ne; ++l) {
  409. sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
  410. }
  411. }
  412. }
  413. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
  414. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  415. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  416. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  417. const int kbx = threadIdx.x / QI5_1;
  418. const int kqsx = threadIdx.x % QI5_1;
  419. #pragma unroll
  420. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  421. int i = i0 + threadIdx.y;
  422. if (need_check) {
  423. i = min(i, i_max);
  424. }
  425. const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
  426. const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
  427. const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
  428. int qs0 = (ql >> 0) & 0x0F0F0F0F;
  429. qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
  430. qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
  431. qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
  432. qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
  433. x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
  434. int qs1 = (ql >> 4) & 0x0F0F0F0F;
  435. qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
  436. qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
  437. qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
  438. qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
  439. x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
  440. }
  441. const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
  442. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  443. #pragma unroll
  444. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
  445. int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
  446. if (need_check) {
  447. i = min(i, i_max);
  448. }
  449. const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
  450. x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
  451. }
  452. }
  453. template <int mmq_x, int mmq_y, int nwarps>
  454. static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
  455. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  456. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  457. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  458. const int * y_qs = (const int *) y + 4;
  459. const half2 * y_ds = (const half2 *) y;
  460. #pragma unroll
  461. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  462. const int j = j0 + threadIdx.y;
  463. #pragma unroll
  464. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  465. const int i = i0 + threadIdx.x;
  466. const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
  467. const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
  468. int u[2*VDR_Q5_1_Q8_1_MMQ];
  469. #pragma unroll
  470. for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
  471. u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
  472. u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
  473. }
  474. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
  475. (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
  476. }
  477. }
  478. }
  479. template <int mmq_x, int mmq_y, int nwarps>
  480. static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
  481. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  482. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  483. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  484. typedef mma_int_A_I16K8 mma_A;
  485. typedef mma_int_B_J8K8 mma_B;
  486. typedef mma_int_C_I16J8 mma_C;
  487. const int * y_qs = (const int *) y + 4;
  488. const half2 * y_ds = (const half2 *) y;
  489. mma_A A;
  490. half2 dmA[mma_C::ne/2];
  491. const int i0 = threadIdx.y*mma_A::I;
  492. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  493. #pragma unroll
  494. for (int l = 0; l < mma_A::ne; ++l) {
  495. const int i = i0 + mma_A::get_i(l);
  496. const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
  497. A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
  498. }
  499. #pragma unroll
  500. for (int l = 0; l < mma_C::ne/2; ++l) {
  501. const int i = i0 + mma_C::get_i(2*l);
  502. dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
  503. }
  504. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  505. mma_C C;
  506. mma_B B;
  507. half2 dsB[mma_C::ne/2];
  508. #pragma unroll
  509. for (int l = 0; l < mma_B::ne; ++l) {
  510. const int j = j0 + mma_B::get_j(l);
  511. const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
  512. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  513. }
  514. #pragma unroll
  515. for (int l = 0; l < mma_C::ne/2; ++l) {
  516. const int j = j0 + mma_C::get_j(l);
  517. dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
  518. }
  519. C.mma_K8(A, B);
  520. #pragma unroll
  521. for (int l = 0; l < mma_C::ne; ++l) {
  522. const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
  523. sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
  524. }
  525. }
  526. }
  527. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
  528. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  529. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  530. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  531. const int kbx = threadIdx.x / QI8_0;
  532. const int kqsx = threadIdx.x % QI8_0;
  533. float * x_dmf = (float *) x_dm;
  534. #pragma unroll
  535. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  536. int i = i0 + threadIdx.y;
  537. if (need_check) {
  538. i = min(i, i_max);
  539. }
  540. const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
  541. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
  542. }
  543. const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
  544. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  545. #pragma unroll
  546. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
  547. int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;
  548. if (need_check) {
  549. i = min(i, i_max);
  550. }
  551. const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
  552. x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
  553. }
  554. }
  555. template <int mmq_x, int mmq_y, int nwarps>
  556. static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
  557. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  558. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  559. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  560. const float * x_dmf = (const float *) x_dm;
  561. const int * y_qs = (const int *) y + 4;
  562. const float * y_df = (const float *) y;
  563. #pragma unroll
  564. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  565. const int j = j0 + threadIdx.y;
  566. #pragma unroll
  567. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  568. const int i = i0 + threadIdx.x;
  569. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
  570. (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
  571. y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
  572. }
  573. }
  574. }
  575. template <int mmq_x, int mmq_y, int nwarps>
  576. static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
  577. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  578. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  579. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  580. typedef mma_int_A_I16K8 mma_A;
  581. typedef mma_int_B_J8K8 mma_B;
  582. typedef mma_int_C_I16J8 mma_C;
  583. const float * x_df = (const float *) x_dm;
  584. const int * y_qs = (const int *) y + 4;
  585. const float * y_df = (const float *) y;
  586. mma_A A;
  587. float dA[mma_C::ne/2];
  588. const int i0 = threadIdx.y*mma_A::I;
  589. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  590. #pragma unroll
  591. for (int l = 0; l < mma_A::ne; ++l) {
  592. const int i = i0 + mma_A::get_i(l);
  593. const int k = k0 + mma_A::get_k(l);
  594. A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
  595. }
  596. #pragma unroll
  597. for (int l = 0; l < mma_C::ne/2; ++l) {
  598. const int i = i0 + mma_C::get_i(2*l);
  599. dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
  600. }
  601. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  602. mma_C C;
  603. mma_B B;
  604. float dB[mma_C::ne/2];
  605. #pragma unroll
  606. for (int l = 0; l < mma_B::ne; ++l) {
  607. const int j = j0 + mma_B::get_j(l);
  608. const int k = k0 + mma_B::get_k(l);
  609. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  610. }
  611. #pragma unroll
  612. for (int l = 0; l < mma_C::ne/2; ++l) {
  613. const int j = j0 + mma_C::get_j(l);
  614. dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
  615. }
  616. C.mma_K8(A, B);
  617. #pragma unroll
  618. for (int l = 0; l < mma_C::ne; ++l) {
  619. sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
  620. }
  621. }
  622. }
  623. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
  624. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  625. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  626. GGML_UNUSED(x_qh);
  627. const int kbx = threadIdx.x / QI2_K;
  628. const int kqsx = threadIdx.x % QI2_K;
  629. #pragma unroll
  630. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  631. int i = i0 + threadIdx.y;
  632. if (need_check) {
  633. i = min(i, i_max);
  634. }
  635. const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
  636. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  637. }
  638. const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
  639. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  640. #pragma unroll
  641. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
  642. int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
  643. if (need_check) {
  644. i = min(i, i_max);
  645. }
  646. const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd;
  647. x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
  648. }
  649. #pragma unroll
  650. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
  651. int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
  652. if (need_check) {
  653. i = min(i, i_max);
  654. }
  655. const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4);
  656. x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4));
  657. }
  658. }
  659. template <int mmq_x, int mmq_y, int nwarps>
  660. static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
  661. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  662. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  663. GGML_UNUSED(x_qh);
  664. const int * y_qs = (const int *) y + 4;
  665. const float * y_df = (const float *) y;
  666. #pragma unroll
  667. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  668. const int j = j0 + threadIdx.y;
  669. #pragma unroll
  670. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  671. const int i = i0 + threadIdx.x;
  672. const int kbx = k0 / QI2_K;
  673. const int ky = (k0 % QI2_K) * QR2_K;
  674. int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
  675. const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
  676. const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
  677. #pragma unroll
  678. for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
  679. v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
  680. }
  681. const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
  682. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
  683. v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales,
  684. x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
  685. }
  686. }
  687. }
  688. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
  689. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  690. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  691. const int kbx = threadIdx.x / QI3_K;
  692. const int kqsx = threadIdx.x % QI3_K;
  693. #pragma unroll
  694. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  695. int i = i0 + threadIdx.y;
  696. if (need_check) {
  697. i = min(i, i_max);
  698. }
  699. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
  700. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
  701. }
  702. const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
  703. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  704. float * x_dmf = (float *) x_dm;
  705. #pragma unroll
  706. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
  707. int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
  708. if (need_check) {
  709. i = min(i, i_max);
  710. }
  711. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
  712. x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
  713. }
  714. #pragma unroll
  715. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
  716. int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2);
  717. if (need_check) {
  718. i = min(i, i_max);
  719. }
  720. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2);
  721. // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
  722. x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2));
  723. }
  724. #pragma unroll
  725. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
  726. int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
  727. if (need_check) {
  728. i = min(i, i_max);
  729. }
  730. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);
  731. const int ksc = threadIdx.x % (QI3_K/4);
  732. const int ksc_low = ksc % (QI3_K/8);
  733. const int shift_low = 4 * (ksc / (QI3_K/8));
  734. const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
  735. const int ksc_high = QI3_K/8;
  736. const int shift_high = 2 * ksc;
  737. const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
  738. const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
  739. x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
  740. }
  741. }
  742. template <int mmq_x, int mmq_y, int nwarps>
  743. static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
  744. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  745. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  746. const float * x_dmf = (const float *) x_dm;
  747. const int * y_qs = (const int *) y + 4;
  748. const float * y_df = (const float *) y;
  749. #pragma unroll
  750. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  751. const int j = j0 + threadIdx.y;
  752. #pragma unroll
  753. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  754. const int i = i0 + threadIdx.x;
  755. const int kbx = k0 / QI3_K;
  756. const int ky = (k0 % QI3_K) * QR3_K;
  757. const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
  758. int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
  759. #pragma unroll
  760. for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
  761. const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
  762. const int shift = 2 * ((ky % 32) / 8);
  763. const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
  764. const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
  765. const int vlh = (vh << 2) & 0x04040404;
  766. v[l] = __vsubss4(vll, vlh);
  767. }
  768. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
  769. v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
  770. x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
  771. }
  772. }
  773. }
  774. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
  775. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  776. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  777. GGML_UNUSED(x_qh);
  778. const int kbx = 0; // threadIdx.x / QI4_K
  779. const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
  780. #pragma unroll
  781. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  782. int i = i0 + threadIdx.y;
  783. if (need_check) {
  784. i = min(i, i_max);
  785. }
  786. const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
  787. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  788. }
  789. const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
  790. const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
  791. #pragma unroll
  792. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
  793. int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
  794. if (need_check) {
  795. i = min(i, i_max);
  796. }
  797. const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
  798. x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
  799. }
  800. #pragma unroll
  801. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
  802. int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
  803. if (need_check) {
  804. i = min(i, i_max);
  805. }
  806. const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
  807. const int * scales = (const int *) bxi->scales;
  808. const int ksc = threadIdx.x % (WARP_SIZE/8);
  809. // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
  810. int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
  811. scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
  812. x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
  813. }
  814. }
  815. template <int mmq_x, int mmq_y, int nwarps>
  816. static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
  817. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  818. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  819. GGML_UNUSED(x_qh);
  820. const int * y_qs = (const int *) y + 4;
  821. const half2 * y_ds = (const half2 *) y;
  822. #pragma unroll
  823. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  824. const int j = j0 + threadIdx.y;
  825. #pragma unroll
  826. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  827. const int i = i0 + threadIdx.x;
  828. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
  829. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
  830. &x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
  831. x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
  832. }
  833. }
  834. }
  835. template <int mmq_x, int mmq_y, int nwarps>
  836. static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
  837. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  838. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  839. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  840. typedef mma_int_A_I16K8 mma_A;
  841. typedef mma_int_B_J8K8 mma_B;
  842. typedef mma_int_C_I16J8 mma_C;
  843. const int * y_qs = (const int *) y + 4;
  844. const half2 * y_ds = (const half2 *) y;
  845. const int i0 = threadIdx.y*mma_A::I;
  846. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  847. mma_A A[2];
  848. int scA[mma_C::ne/2][2];
  849. int mA[mma_C::ne/2][2];
  850. half2 dmA[mma_C::ne/2];
  851. #pragma unroll
  852. for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
  853. #pragma unroll
  854. for (int l = 0; l < mma_A::ne; ++l) {
  855. const int i = i0 + mma_A::get_i(l);
  856. const int k = k0 + mma_A::get_k(l);
  857. A[kvdr/4].x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
  858. }
  859. #pragma unroll
  860. for (int l = 0; l < mma_C::ne/2; ++l) {
  861. const int i = i0 + mma_C::get_i(2*l);
  862. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
  863. const uint8_t * m = sc + 8;
  864. scA[l][kvdr/4] = sc[kvdr/4];
  865. mA[l][kvdr/4] = m[kvdr/4];
  866. }
  867. }
  868. #pragma unroll
  869. for (int l = 0; l < mma_C::ne/2; ++l) {
  870. const int i = i0 + mma_C::get_i(2*l);
  871. dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
  872. }
  873. #pragma unroll
  874. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  875. float tmpd[mma_C::ne] = {0.0f};
  876. float tmpm[mma_C::ne] = {0.0f};
  877. #pragma unroll
  878. for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
  879. mma_C C;
  880. mma_B B;
  881. half2 dsB[mma_C::ne/2];
  882. #pragma unroll
  883. for (int l = 0; l < mma_B::ne; ++l) {
  884. const int j = j0 + mma_B::get_j(l);
  885. const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
  886. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  887. }
  888. #pragma unroll
  889. for (int l = 0; l < mma_C::ne/2; ++l) {
  890. const int j = j0 + mma_C::get_j(l);
  891. dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
  892. }
  893. C.mma_K8(A[kvdr/4], B);
  894. #pragma unroll
  895. for (int l = 0; l < mma_C::ne; ++l) {
  896. tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
  897. tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
  898. }
  899. }
  900. #pragma unroll
  901. for (int l = 0; l < mma_C::ne; ++l) {
  902. sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
  903. }
  904. }
  905. }
  906. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
  907. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  908. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  909. GGML_UNUSED(x_qh);
  910. const int kbx = 0; // threadIdx.x / QI5_K
  911. const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
  912. #pragma unroll
  913. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  914. int i = i0 + threadIdx.y;
  915. if (need_check) {
  916. i = min(i, i_max);
  917. }
  918. const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
  919. const int ky = QR5_K*kqsx;
  920. const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
  921. const int ql0 = (ql >> 0) & 0x0F0F0F0F;
  922. const int ql1 = (ql >> 4) & 0x0F0F0F0F;
  923. const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
  924. const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
  925. const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
  926. const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
  927. const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
  928. x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
  929. x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
  930. }
  931. const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
  932. const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
  933. #pragma unroll
  934. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
  935. int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
  936. if (need_check) {
  937. i = min(i, i_max);
  938. }
  939. const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
  940. x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
  941. }
  942. #pragma unroll
  943. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
  944. int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
  945. if (need_check) {
  946. i = min(i, i_max);
  947. }
  948. const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);
  949. const int * scales = (const int *) bxi->scales;
  950. const int ksc = threadIdx.x % (WARP_SIZE/8);
  951. // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8
  952. int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
  953. scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
  954. x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
  955. }
  956. }
  957. template <int mmq_x, int mmq_y, int nwarps>
  958. static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
  959. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  960. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  961. GGML_UNUSED(x_qh);
  962. const int * y_qs = (const int *) y + 4;
  963. const half2 * y_ds = (const half2 *) y;
  964. #pragma unroll
  965. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  966. const int j = j0 + threadIdx.y;
  967. #pragma unroll
  968. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  969. const int i = i0 + threadIdx.x;
  970. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
  971. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
  972. &x_ql[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
  973. x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
  974. }
  975. }
  976. }
  977. template <int mmq_x, int mmq_y, int nwarps>
  978. static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
  979. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  980. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  981. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  982. typedef mma_int_A_I16K8 mma_A;
  983. typedef mma_int_B_J8K8 mma_B;
  984. typedef mma_int_C_I16J8 mma_C;
  985. const int * y_qs = (const int *) y + 4;
  986. const half2 * y_ds = (const half2 *) y;
  987. const int i0 = threadIdx.y*mma_A::I;
  988. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  989. mma_A A[2];
  990. int scA[mma_C::ne/2][2];
  991. int mA[mma_C::ne/2][2];
  992. half2 dmA[mma_C::ne/2];
  993. #pragma unroll
  994. for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
  995. #pragma unroll
  996. for (int l = 0; l < mma_A::ne; ++l) {
  997. const int i = i0 + mma_A::get_i(l);
  998. const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
  999. A[kvdr/4].x[l] = x_ql[i*(QR5_K*WARP_SIZE + 1) + k];
  1000. }
  1001. #pragma unroll
  1002. for (int l = 0; l < mma_C::ne/2; ++l) {
  1003. const int i = i0 + mma_C::get_i(2*l);
  1004. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
  1005. const uint8_t * m = sc + 8;
  1006. scA[l][kvdr/4] = sc[kvdr/4];
  1007. mA[l][kvdr/4] = m[kvdr/4];
  1008. }
  1009. }
  1010. #pragma unroll
  1011. for (int l = 0; l < mma_C::ne/2; ++l) {
  1012. const int i = i0 + mma_C::get_i(2*l);
  1013. dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
  1014. }
  1015. #pragma unroll
  1016. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  1017. float tmpd[mma_C::ne] = {0.0f};
  1018. float tmpm[mma_C::ne] = {0.0f};
  1019. #pragma unroll
  1020. for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
  1021. mma_C C;
  1022. mma_B B;
  1023. half2 dsB[mma_C::ne/2];
  1024. #pragma unroll
  1025. for (int l = 0; l < mma_B::ne; ++l) {
  1026. const int j = j0 + mma_B::get_j(l);
  1027. const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
  1028. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  1029. }
  1030. #pragma unroll
  1031. for (int l = 0; l < mma_C::ne/2; ++l) {
  1032. const int j = j0 + mma_C::get_j(l);
  1033. dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
  1034. }
  1035. C.mma_K8(A[kvdr/4], B);
  1036. #pragma unroll
  1037. for (int l = 0; l < mma_C::ne; ++l) {
  1038. tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
  1039. tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
  1040. }
  1041. }
  1042. #pragma unroll
  1043. for (int l = 0; l < mma_C::ne; ++l) {
  1044. sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
  1045. }
  1046. }
  1047. }
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh);

    const int kbx  = 0;           // threadIdx.x / QI6_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI6_K

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;

        const int ky = QR6_K*kqsx;

        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;

        const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);

        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K;  // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
    float * x_dmf = (float *) x_dm;

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;

        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }

#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;

        x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
    }
}
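
// dp4a (per-thread) variant of the Q6_K x Q8_1 tile dot product.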
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_qh);

    const float * x_dmf = (const float *) x_dm;
    const int   * y_qs  = (const int   *) y + 4;
    const float * y_df  = (const float *) y;

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
                &x_ql[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
                x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
        }
    }
}
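
// Tensor core (int8 MMA) variant of the Q6_K x Q8_1 tile dot product.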
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    typedef mma_int_A_I16K4 mma_A;
    typedef mma_int_B_J8K4  mma_B;
    typedef mma_int_C_I16J8 mma_C;

    const float * x_df = (const float *) x_dm;
    const int   * y_qs = (const int   *) y + 4;
    const float * y_df = (const float *) y;

    const int i0 = threadIdx.y*mma_A::I;
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");

    mma_A A[4];
    int   scA[mma_C::ne/2][4];
    float  dA[mma_C::ne/2];

#pragma unroll
    for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
#pragma unroll
        for (int l = 0; l < mma_A::ne; ++l) {
            const int i = i0 + mma_A::get_i(l);
            const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);

            A[kvdr/2 + 0].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + 0];
            A[kvdr/2 + 1].x[l] = x_ql[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + mma_C::get_i(2*l);

            const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);

            scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
            scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
        }
    }

#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);

        dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
    }

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        float tmp[mma_C::ne] = {0.0f};

#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
            mma_C C[2];
            mma_B B[2];
            float dB[mma_C::ne/2];

#pragma unroll
            for (int l = 0; l < mma_B::ne; ++l) {
                const int j = j0 + mma_B::get_j(l);
                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;

                B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
                B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
            }

#pragma unroll
            for (int l = 0; l < mma_C::ne/2; ++l) {
                const int j = j0 + mma_C::get_j(l);

                dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
            }

            C[0].mma_K4(A[kvdr/2 + 0], B[0]);
            C[1].mma_K4(A[kvdr/2 + 1], B[1]);

#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
            }
        }

#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
        }
    }
}
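
// Write the per-thread accumulators of the dp4a path back to dst, with bounds checks on the output rows and columns.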
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;

        if (j >= ne1) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;

            if (need_check && i >= ne0) {
                continue;
            }

            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
        }
    }
}
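
// Write the accumulators of the MMA path back to dst, using the mma_int_C_I16J8 thread-to-element mapping.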
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
    typedef mma_int_C_I16J8 mma_C;

    const int i0 = threadIdx.y*mma_C::I;
    static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);

            if (j >= ne1) {
                continue;
            }

            const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);

            if (need_check && i >= ne0) {
                continue;
            }

            dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
        }
    }
}
// -------------------------------------------------------------------------------------------------------------------------------------
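
// Compile-time dispatch: selects the tile loader, the dot product kernel, and the write-back function for each quantization type.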
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
struct mmq_type_traits;

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
    static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
#ifdef INT8_MMA_AVAILABLE
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
    static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
#ifdef INT8_MMA_AVAILABLE
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
    static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
#ifdef INT8_MMA_AVAILABLE
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
    static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
#ifdef INT8_MMA_AVAILABLE
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
    static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
#ifdef INT8_MMA_AVAILABLE
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
    static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
    static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
    static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
#ifdef INT8_MMA_AVAILABLE
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
    static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
#ifdef INT8_MMA_AVAILABLE
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
    static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
#ifdef INT8_MMA_AVAILABLE
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
#else
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
    static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
#endif // INT8_MMA_AVAILABLE
};
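
// Whether the MMQ kernel for this quantization type consumes the per-block sums of the quantized y (q8_1) data.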
static bool mmq_need_sum(const ggml_type type_x) {
    switch (type_x) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
            return true;
        case GGML_TYPE_Q5_0:
            return false;
        case GGML_TYPE_Q5_1:
            return true;
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
            return false;
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
            return true;
        case GGML_TYPE_Q6_K:
            return false;
        default:
            GGML_ASSERT(false);
            break;
    }
    return false;
}
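
// Main MMQ kernel: each block computes an mmq_y x mmq_x tile of dst. The x tile is loaded into shared memory
// via load_tiles, the quantized y data is streamed in per k-slice, and the selected vec_dot accumulates into sum.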
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
    __launch_bounds__(WARP_SIZE*nwarps, 1)
#else
    __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2)
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {

    // Skip unused template specializations for faster compilation:
    if (mmq_x > get_mmq_x_max_device()) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int qk = ggml_cuda_type_traits<type>::qk;
    constexpr int qr = ggml_cuda_type_traits<type>::qr;
    constexpr int qi = ggml_cuda_type_traits<type>::qi;
    constexpr int mmq_y = get_mmq_y_device(mmq_x);
    constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
    constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
    constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;

    constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);

    extern __shared__ char data_mul_mat_q[];
    int   * tile_x_ql = (int   *)  data_mul_mat_q;
    half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
    int   * tile_x_qh = (int   *) (tile_x_dm + txs.dm);
    int   * tile_x_sc = (int   *) (tile_x_qh + txs.qh);
    int   * tile_y    = (int   *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]

    const int blocks_per_row_x = ne00 / qk;
    const int blocks_per_warp  = WARP_SIZE / qi;

    const int & ne1 = ne11;

    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;

    const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));

    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};

    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
        load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);

#pragma unroll
        for (int kr = 0; kr < qr; ++kr) {
            const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
                int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;

                tile_y[l] = by0[l];
            }

            __syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
                vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y, sum, k0);
            }

            __syncthreads();
        }
    }

    write_back(sum, dst, ne0, ne1);
}
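
// Host-side launch parameters for mul_mat_q.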
struct mmq_args {
    const char * x; const char * y; float * dst;
    int64_t ne00; int64_t ne01; int64_t stride01;
    int64_t ne10; int64_t ne11; int64_t stride11;
    int64_t ne0;
};
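
// Launches mul_mat_q for a fixed mmq_x/nwarps combination: computes the grid and block dimensions and the
// required dynamic shared memory, raising the per-kernel shared memory limit on CUDA devices if necessary.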
template <ggml_type type, int mmq_x, int nwarps>
static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
    const int id    = ggml_cuda_get_device();
    const int cc    = ggml_cuda_info().devices[id].cc;
    const int mmq_y = get_mmq_y_host(cc, mmq_x);

    const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
    const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
    const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
    const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
    const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
    static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
    if (!shmem_limit_raised[id]) {
        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
        shmem_limit_raised[id] = true;
    }
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

    if (args.ne01 % mmq_y == 0) {
        const bool need_check = false;
        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
    } else {
        const bool need_check = true;
        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
    }
}
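
// Picks the mmq_x tile size that minimizes the number of waves for the given problem size and launches the kernel.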
template <ggml_type type>
void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
    const int id  = ggml_cuda_get_device();
    const int nsm = ggml_cuda_info().devices[id].nsm;
    const int cc  = ggml_cuda_info().devices[id].cc;

    const int mmq_x_max = get_mmq_x_max_host(cc);
    const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
    const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;

    int mmq_x_best  = 0;
    int nwaves_best = INT_MAX;

    for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
        const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
        const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;

        if (nwaves < nwaves_best) {
            mmq_x_best  = mmq_x;
            nwaves_best = nwaves;
        }
    }

    switch (mmq_x_best) {
        case 8:
            launch_mul_mat_q<type, 8, 4>(args, stream);
            break;
        case 16:
            launch_mul_mat_q<type, 16, 4>(args, stream);
            break;
        case 24:
            launch_mul_mat_q<type, 24, 4>(args, stream);
            break;
        case 32:
            launch_mul_mat_q<type, 32, 8>(args, stream);
            break;
        case 40:
            launch_mul_mat_q<type, 40, 8>(args, stream);
            break;
        case 48:
            launch_mul_mat_q<type, 48, 8>(args, stream);
            break;
        case 56:
            launch_mul_mat_q<type, 56, 8>(args, stream);
            break;
        case 64:
            launch_mul_mat_q<type, 64, 8>(args, stream);
            break;
        case 72:
            launch_mul_mat_q<type, 72, 8>(args, stream);
            break;
        case 80:
            launch_mul_mat_q<type, 80, 8>(args, stream);
            break;
        case 88:
            launch_mul_mat_q<type, 88, 8>(args, stream);
            break;
        case 96:
            launch_mul_mat_q<type, 96, 8>(args, stream);
            break;
        case 104:
            launch_mul_mat_q<type, 104, 8>(args, stream);
            break;
        case 112:
            launch_mul_mat_q<type, 112, 8>(args, stream);
            break;
        case 120:
            launch_mul_mat_q<type, 120, 8>(args, stream);
            break;
        case 128:
            launch_mul_mat_q<type, 128, 8>(args, stream);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
}
#define DECL_MMQ_CASE(type)                                                         \
    template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream)  \

extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);

// -------------------------------------------------------------------------------------------------------------------------
void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream);

bool ggml_cuda_supports_mmq(enum ggml_type type);