mmq.cuh

#pragma once

#include "common.cuh"
#include "vecdotq.cuh"
#include "mma.cuh"

#include <climits>
#include <cstdint>

#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
#define MMQ_NWARPS 8

typedef void (*load_tiles_mmq_t)(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
typedef void (*vec_dot_mmq_t)(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);

struct block_q8_1_mmq {
    half2 ds[4];
    int8_t qs[4*QK8_1];
};
static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
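// Note (added): with the usual ggml-cuda values WARP_SIZE == 32, QK8_1 == 32 and QI8_1 == 8,
// one block_q8_1_mmq packs 4 q8_1 sub-blocks: 4*32 = 128 bytes of int8 quants plus
// 4 half2 = 16 bytes of scales, 144 bytes in total. MMQ_TILE_Y_K = 32 + 32/8 = 36 ints is
// the same 144 bytes, so one tile-Y row stride corresponds to exactly one block_q8_1_mmq,
// and the `(const int *) y + 4` offsets used throughout this file skip the 4 half2 of scales
// at the start of each block. Under those assumptions this could be checked with:
//     static_assert(MMQ_TILE_Y_K*sizeof(int) == sizeof(block_q8_1_mmq), "tile Y stride mismatch");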
struct tile_x_sizes {
    int qs;
    int dm;
    int sc;
};

// get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
static constexpr __device__ int get_mmq_x_max_device() {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    return 64;
#else
#if __CUDA_ARCH__ >= CC_VOLTA
#ifdef CUDA_USE_TENSOR_CORES
    return MMQ_MAX_BATCH_SIZE;
#else
    return 128;
#endif // CUDA_USE_TENSOR_CORES
#else
    return 64;
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}

// get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
static constexpr __device__ int get_mmq_y_device() {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    return 128;
#else
#if __CUDA_ARCH__ >= CC_VOLTA
    return 128;
#else
    return 64;
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}
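// Note (added): mmq_y is the tile size along the rows of the quantized matrix x (the i
// direction in the functions below) and mmq_x the tile size along the columns taken from y
// (the j / batch direction); the two functions above only cap those dimensions per architecture.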
#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0}
#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0}
#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0}
#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0}
#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/4 + mmq_y/4}
#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}

#define GET_TILE_X_SIZES_BODY                           \
    return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
        type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 :    \
        type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 :    \
        type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 :    \
        type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 :    \
        type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K :    \
        type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K :    \
        type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K :    \
        type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K :    \
        type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K :    \
        tile_x_sizes{0, 0, 0}

static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
    GET_TILE_X_SIZES_BODY;
}

template <int mmq_y>
static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
    GET_TILE_X_SIZES_BODY;
}
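// Example (added, illustrative only): assuming WARP_SIZE == 32, GGML_TYPE_Q4_0 with
// mmq_y == 64 yields a qs tile of 64*32 + 64 = 2112 ints, a dm tile of
// 64*32/QI4_0 + 64/QI4_0 scale entries and no sc tile; the +mmq_y terms are one element
// of padding per tile row (the WARP_SIZE + 1 strides used below).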
// ------------------------------------------------------------
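// Note (added): every quantization type below follows the same pattern: a load_tiles_*
// function that unpacks one tile of x into shared memory, a vec_dot_*_dp4a variant that
// accumulates the dot products with per-thread dp4a instructions, and a vec_dot_*_mma
// variant that uses int8 tensor core fragments when INT8_MMA_AVAILABLE is defined.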
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_sc);
    const int kbx = threadIdx.x / QI4_0;
    const int kqsx = threadIdx.x % QI4_0;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
        int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
    }
}
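// Note (added): the row stride of WARP_SIZE + 1 ints in x_qs (and the analogous +1 terms in
// the x_dm/x_sc strides) is the usual shared-memory padding trick: offsetting consecutive
// rows by one element keeps column-wise accesses from all mapping to the same bank.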
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_sc);
    const float * x_df = (const float *) x_dm;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
            int u[2*VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
            for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
            }
            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    GGML_UNUSED(x_sc);
    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;
    const float * x_df = (const float *) x_dm;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;
    mma_A A;
    float dA[mma_C::ne/2];
    const int i0 = threadIdx.y*mma_A::I;
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int k = k0 + mma_A::get_k(l) % QI4_0;
        const int shift = 4*(mma_A::get_k(l) / QI4_0);
        A.x[l] = __vsubss4((x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
    }
#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);
        dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
    }
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C;
        mma_B B;
        half2 dsB[mma_C::ne/2];
#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
        }
#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);
            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }
        C.mma_K8(A, B);
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_sc);
    const int kbx = threadIdx.x / QI4_1;
    const int kqsx = threadIdx.x % QI4_1;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
        int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_sc);
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
            int u[2*VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
            for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
            }
            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
                (&x_qs[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
                 y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    GGML_UNUSED(x_sc);
    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;
    mma_A A;
    half2 dmA[mma_C::ne/2];
    const int i0 = threadIdx.y*mma_A::I;
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int k = k0 + mma_A::get_k(l) % QI4_0;
        const int shift = 4*(mma_A::get_k(l) / QI4_0);
        A.x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
    }
#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);
        dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
    }
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C;
        mma_B B;
        half2 dsB[mma_C::ne/2];
#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
        }
#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);
            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }
        C.mma_K8(A, B);
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
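// Note (added): q5_0/q5_1 tiles are unpacked to full 8-bit values (low and high nibble plus
// the fifth bit taken from qh), so each thread stores two ints per block and the x_qs row
// stride grows to 2*WARP_SIZE + 1 accordingly.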
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_sc);
    const int kbx = threadIdx.x / QI5_0;
    const int kqsx = threadIdx.x % QI5_0;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
        const int ql = get_int_from_uint8(bxi->qs, kqsx);
        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
        qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_sc);
    const float * x_dmf = (const float *) x_dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
            const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
            int u[2*VDR_Q5_0_Q8_1_MMQ];
#pragma unroll
            for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
            }
            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    GGML_UNUSED(x_sc);
    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;
    const float * x_df = (const float *) x_dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
    mma_A A;
    float dA[mma_C::ne/2];
    const int i0 = threadIdx.y*mma_A::I;
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
    }
#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);
        dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
    }
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C;
        mma_B B;
        float dB[mma_C::ne/2];
#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
        }
#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);
            dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }
        C.mma_K8(A, B);
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_sc);
    const int kbx = threadIdx.x / QI5_1;
    const int kqsx = threadIdx.x % QI5_1;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
        x_qs[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_sc);
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
            const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
            int u[2*VDR_Q5_1_Q8_1_MMQ];
#pragma unroll
            for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
                u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
                u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
            }
            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
                (&x_qs[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    GGML_UNUSED(x_sc);
    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;
    mma_A A;
    half2 dmA[mma_C::ne/2];
    const int i0 = threadIdx.y*mma_A::I;
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
        A.x[l] = x_qs[i*(2*WARP_SIZE + 1) + k];
    }
#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);
        dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
    }
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C;
        mma_B B;
        half2 dsB[mma_C::ne/2];
#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
        }
#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);
            dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
        }
        C.mma_K8(A, B);
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
            sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_sc);
    const int kbx = threadIdx.x / QI8_0;
    const int kqsx = threadIdx.x % QI8_0;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
        int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_sc);
    const float * x_dmf = (const float *) x_dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
                (&x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
                 y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    GGML_UNUSED(x_sc);
    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;
    const float * x_df = (const float *) x_dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
    mma_A A;
    float dA[mma_C::ne/2];
    const int i0 = threadIdx.y*mma_A::I;
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int k = k0 + mma_A::get_k(l);
        A.x[l] = x_qs[i*(WARP_SIZE + 1) + k];
    }
#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);
        dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
    }
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C;
        mma_B B;
        float dB[mma_C::ne/2];
#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = k0 + mma_B::get_k(l);
            B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
        }
#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);
            dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
        }
        C.mma_K8(A, B);
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = threadIdx.x / QI2_K;
    const int kqsx = threadIdx.x % QI2_K;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
#pragma unroll
        for (int l = 0; l < QR2_K; ++l) {
            const int k = kbx*QI2_K + (kqsx/8)*8 + l*2 + (kqsx % 8)/4;
            int x_qs_k = ((x_ql_0 >> (2*l)) & 0x03030303) << (2*(kqsx % 4));
            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 2, WARP_SIZE);
            if (kqsx % QR2_K != 0) {
                continue;
            }
            x_qs[i*(WARP_SIZE + 1) + k] = x_qs_k;
        }
        const int sc_m = bxi->scales[kqsx];
#ifdef FAST_FP16_AVAILABLE
        const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
#else
        const float2 bxi_dmf = __half22float2(bxi->dm);
        const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
#endif // FAST_FP16_AVAILABLE
        x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
                &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE],
                &x_dm[i*(WARP_SIZE + 1) + k0], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    typedef mma_int_A_I16K4 mma_A;
    typedef mma_int_B_J8K4 mma_B;
    typedef mma_int_C_I16J8 mma_C;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
    const int i0 = threadIdx.y*mma_A::I;
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
    mma_A A[2];
    float dA[mma_C::ne/2][2];
    float mA[mma_C::ne/2][2];
#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int shift = 2*mma_A::get_k(l);
        A[0].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 0] >> shift) & 0x03030303;
        A[1].x[l] = (x_qs[i*(WARP_SIZE + 1) + k0 + 1] >> shift) & 0x03030303;
    }
#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);
#pragma unroll
        for (int kk = 0; kk < 2; ++kk) {
            const float2 dm = __half22float2(x_dm[i*(WARP_SIZE + 1) + k0 + kk]);
            dA[l][kk] = dm.x;
            mA[l][kk] = dm.y;
        }
    }
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C Cd[2];
        mma_C Cm[2];
        mma_B B[2];
        float dB[mma_C::ne/2];
#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
        }
#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);
            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
        }
        Cd[0].mma_K4(A[0], B[0]);
        Cd[1].mma_K4(A[1], B[1]);
        mma_A A1;
        A1.x[0] = 0x01010101;
        A1.x[1] = 0x01010101;
        Cm[0].mma_K4(A1, B[0]);
        Cm[1].mma_K4(A1, B[1]);
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            sum[(j0/mma_B::J)*mma_C::ne + l] += (Cd[0].x[l]*dA[l/2][0] + Cd[1].x[l]*dA[l/2][1] - Cm[0].x[l]*mA[l/2][0] - Cm[1].x[l]*mA[l/2][1])*dB[l%2];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = threadIdx.x / QI3_K;
    const int kqsx = threadIdx.x % QI3_K;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
        const int x_ql_0 = get_int_from_uint8(bxi->qs, kqsx);
        const int x_qh_0 = get_int_from_uint8(bxi->hmask, kqsx % (QI3_K/2)) >> (4 * (kqsx / (QI3_K/2)));
#pragma unroll
        for (int l = 0; l < QR3_K; ++l) {
            const int k = kbx*(QR3_K*QI3_K) + (kqsx/8)*32 + l*8 + kqsx % 8;
            const int x_ql_k = (x_ql_0 >> (2*l)) & 0x03030303;
            const int x_qh_k = ((x_qh_0 >> l) << 2) & 0x04040404;
            int x_qs_k = (x_ql_k | x_qh_k) << (4*(k%2));
            x_qs_k |= __shfl_xor_sync(0xFFFFFFFF, x_qs_k, 1, WARP_SIZE);
            if (kqsx % 2 != 0) {
                continue;
            }
            x_qs[i*(2*WARP_SIZE + 1) + k/2] = x_qs_k;
        }
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
        int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);
        const int ksc = threadIdx.x % (QI3_K/4);
        const int ksc_low = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
        const int ksc_high = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    const float * x_df = (const float *) x_dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            const int kbx = k0 / QI3_K;
            const int ky = (k0 % QI3_K) * QR3_K;
            const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
                &x_qs[i*(2*WARP_SIZE + 1) + 2*k0], &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
                x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
        }
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    typedef mma_int_A_I16K4 mma_A;
    typedef mma_int_B_J8K4 mma_B;
    typedef mma_int_C_I16J8 mma_C;
    const float * x_df = (const float *) x_dm;
    const int * y_qs = (const int *) y + 4;
    const float * y_df = (const float *) y;
    const int i0 = threadIdx.y*mma_A::I;
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
    mma_A A[2];
    int scA[mma_C::ne/2][2];
    float dA[mma_C::ne/2];
#pragma unroll
    for (int l = 0; l < mma_A::ne; ++l) {
        const int i = i0 + mma_A::get_i(l);
        const int k = QR3_K*k0 + mma_A::get_k(l);
        A[0].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + 0] >> (4*(k%2))) & 0x0F0F0F0F;
        A[1].x[l] = (x_qs[i*(2*WARP_SIZE + 1) + k/2 + mma_A::K/2] >> (4*(k%2))) & 0x0F0F0F0F;
        A[0].x[l] = __vsubss4(A[0].x[l], 0x04040404);
        A[1].x[l] = __vsubss4(A[1].x[l], 0x04040404);
    }
#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);
        const int kbx = k0 / QI3_K;
        const int ky = (k0 % QI3_K) * QR3_K;
        const int8_t * sc = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
        scA[l][0] = sc[0];
        scA[l][1] = sc[1];
    }
#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);
        dA[l] = x_df[i*(WARP_SIZE/QI3_K) + i/QI3_K + k0/QI3_K];
    }
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        mma_C C[2];
        mma_B B[2];
        float dB[mma_C::ne/2];
#pragma unroll
        for (int l = 0; l < mma_B::ne; ++l) {
            const int j = j0 + mma_B::get_j(l);
            const int k = (4*k0 + mma_B::get_k(l)) % WARP_SIZE;
            B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
            B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
        }
#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int j = j0 + mma_C::get_j(l);
            dB[l] = y_df[j*MMQ_TILE_Y_K + ((4*k0)/QI8_1) % (WARP_SIZE/QI8_1)];
        }
        C[0].mma_K4(A[0], B[0]);
        C[1].mma_K4(A[1], B[1]);
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = 0; // threadIdx.x / QI4_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
        x_qs[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
        const int * scales = (const int *) bxi->scales;
        const int ksc = threadIdx.x % (WARP_SIZE/8);
        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;
#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
                &x_qs[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + (QR4_K*k0) % WARP_SIZE], sc, sc+8,
                x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + ((QR4_K*k0) % WARP_SIZE)/QI8_1]);
        }
    }
}
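// Note (added): for q4_K/q5_K the per-superblock half2 dm packs (d, m). The mma variant
// below therefore accumulates the scale contribution (tmpd, from the int8 products) and the
// min contribution (tmpm, from the precomputed column sums in the high half of y's ds)
// separately, then combines them as d*tmpd - m*tmpm.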
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma(
    const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
    const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
#ifdef INT8_MMA_AVAILABLE
    typedef mma_int_A_I16K8 mma_A;
    typedef mma_int_B_J8K8 mma_B;
    typedef mma_int_C_I16J8 mma_C;
    const int * y_qs = (const int *) y + 4;
    const half2 * y_ds = (const half2 *) y;
    const int i0 = threadIdx.y*mma_A::I;
    static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
    mma_A A[2];
    int scA[mma_C::ne/2][2];
    int mA[mma_C::ne/2][2];
    half2 dmA[mma_C::ne/2];
#pragma unroll
    for (int kvdr = 0; kvdr < VDR_Q4_K_Q8_1_MMQ; kvdr += 4) {
#pragma unroll
        for (int l = 0; l < mma_A::ne; ++l) {
            const int i = i0 + mma_A::get_i(l);
            const int k = k0 + mma_A::get_k(l);
            A[kvdr/4].x[l] = (x_qs[i*(WARP_SIZE + 1) + k] >> kvdr) & 0x0F0F0F0F;
        }
#pragma unroll
        for (int l = 0; l < mma_C::ne/2; ++l) {
            const int i = i0 + mma_C::get_i(2*l);
            const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
            const uint8_t * m = sc + 8;
            scA[l][kvdr/4] = sc[kvdr/4];
            mA[l][kvdr/4] = m[kvdr/4];
        }
    }
#pragma unroll
    for (int l = 0; l < mma_C::ne/2; ++l) {
        const int i = i0 + mma_C::get_i(2*l);
        dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
    }
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
        float tmpd[mma_C::ne] = {0.0f};
        float tmpm[mma_C::ne] = {0.0f};
#pragma unroll
        for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
            mma_C C;
            mma_B B;
            half2 dsB[mma_C::ne/2];
#pragma unroll
            for (int l = 0; l < mma_B::ne; ++l) {
                const int j = j0 + mma_B::get_j(l);
                const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
                B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
            }
#pragma unroll
            for (int l = 0; l < mma_C::ne/2; ++l) {
                const int j = j0 + mma_C::get_j(l);
                dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
            }
            C.mma_K8(A[kvdr/4], B);
#pragma unroll
            for (int l = 0; l < mma_C::ne; ++l) {
                tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
                tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
            }
        }
#pragma unroll
        for (int l = 0; l < mma_C::ne; ++l) {
            sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
        }
    }
#else
    GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
    NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
    const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = 0; // threadIdx.x / QI5_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI5_K
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
        const int ky = QR5_K*kqsx;
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
        const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);
        x_qs[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_qs[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;
        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);
        const int * scales = (const int *) bxi->scales;
        const int ksc = threadIdx.x % (WARP_SIZE/8);
1088. // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
  1089. int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
  1090. scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
  1091. x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
  1092. }
  1093. }
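// Illustration (added sketch, not part of the upstream kernels): load_tiles_q5_K above rebuilds each
// 5-bit q5_K weight from a 4-bit low part in ->qs (the 0x0F0F0F0F masks) and a single high bit from
// ->qh shifted into bit 4 (the 0x10101010 masks), giving an unsigned value in [0, 31] that the dot
// product later combines with the q5_K scales and mins. A minimal per-byte version of that packing,
// with a hypothetical name, looks like this:
static __device__ __forceinline__ int q5_K_unpack_byte_example(const int ql_nibble, const int qh_bit) {
    return (ql_nibble & 0x0F) | ((qh_bit & 1) << 4); // e.g. ql_nibble = 0x7 and qh_bit = 1 give 0x17 == 23
}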
  1094. template <int mmq_x, int mmq_y, int nwarps>
  1095. static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
  1096. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  1097. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  1098. const int * y_qs = (const int *) y + 4;
  1099. const half2 * y_ds = (const half2 *) y;
  1100. #pragma unroll
  1101. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1102. const int j = j0 + threadIdx.y;
  1103. #pragma unroll
  1104. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1105. const int i = i0 + threadIdx.x;
  1106. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
  1107. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
  1108. &x_qs[i*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR5_K*k0) % WARP_SIZE], sc, sc+8,
  1109. x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + ((QR5_K*k0) % WARP_SIZE)/QI8_1]);
  1110. }
  1111. }
  1112. }
  1113. template <int mmq_x, int mmq_y, int nwarps>
  1114. static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma(
  1115. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  1116. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  1117. #ifdef INT8_MMA_AVAILABLE
  1118. typedef mma_int_A_I16K8 mma_A;
  1119. typedef mma_int_B_J8K8 mma_B;
  1120. typedef mma_int_C_I16J8 mma_C;
  1121. const int * y_qs = (const int *) y + 4;
  1122. const half2 * y_ds = (const half2 *) y;
  1123. const int i0 = threadIdx.y*mma_A::I;
  1124. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  1125. mma_A A[2];
  1126. int scA[mma_C::ne/2][2];
  1127. int mA[mma_C::ne/2][2];
  1128. half2 dmA[mma_C::ne/2];
  1129. #pragma unroll
  1130. for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
  1131. #pragma unroll
  1132. for (int l = 0; l < mma_A::ne; ++l) {
  1133. const int i = i0 + mma_A::get_i(l);
  1134. const int k = QR5_K*k0 + QR5_K*kvdr + mma_A::get_k(l);
  1135. A[kvdr/4].x[l] = x_qs[i*(QR5_K*WARP_SIZE + 1) + k];
  1136. }
  1137. #pragma unroll
  1138. for (int l = 0; l < mma_C::ne/2; ++l) {
  1139. const int i = i0 + mma_C::get_i(2*l);
  1140. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
  1141. const uint8_t * m = sc + 8;
  1142. scA[l][kvdr/4] = sc[kvdr/4];
  1143. mA[l][kvdr/4] = m[kvdr/4];
  1144. }
  1145. }
  1146. #pragma unroll
  1147. for (int l = 0; l < mma_C::ne/2; ++l) {
  1148. const int i = i0 + mma_C::get_i(2*l);
  1149. dmA[l] = x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + k0/QI5_K];
  1150. }
  1151. #pragma unroll
  1152. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  1153. float tmpd[mma_C::ne] = {0.0f};
  1154. float tmpm[mma_C::ne] = {0.0f};
  1155. #pragma unroll
  1156. for (int kvdr = 0; kvdr < VDR_Q5_K_Q8_1_MMQ; kvdr += 4) {
  1157. mma_C C;
  1158. mma_B B;
  1159. half2 dsB[mma_C::ne/2];
  1160. #pragma unroll
  1161. for (int l = 0; l < mma_B::ne; ++l) {
  1162. const int j = j0 + mma_B::get_j(l);
  1163. const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
  1164. B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
  1165. }
  1166. #pragma unroll
  1167. for (int l = 0; l < mma_C::ne/2; ++l) {
  1168. const int j = j0 + mma_C::get_j(l);
  1169. dsB[l] = y_ds[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
  1170. }
  1171. C.mma_K8(A[kvdr/4], B);
  1172. #pragma unroll
  1173. for (int l = 0; l < mma_C::ne; ++l) {
  1174. tmpd[l] += (C.x[l]*scA[l/2][kvdr/4]) * __low2float(dsB[l%2]);
  1175. tmpm[l] += mA[l/2][kvdr/4] * __high2float(dsB[l%2]);
  1176. }
  1177. }
  1178. #pragma unroll
  1179. for (int l = 0; l < mma_C::ne; ++l) {
  1180. sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l];
  1181. }
  1182. }
  1183. #else
  1184. GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
  1185. NO_DEVICE_CODE;
  1186. #endif // INT8_MMA_AVAILABLE
  1187. }
  1188. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
  1189. const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
  1190. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  1191. const int kbx = 0; // threadIdx.x / QI6_K
  1192. const int kqsx = threadIdx.x; // threadIdx.x % QI6_K
  1193. #pragma unroll
  1194. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  1195. int i = i0 + threadIdx.y;
  1196. if (need_check) {
  1197. i = min(i, i_max);
  1198. }
  1199. const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
  1200. const int ky = QR6_K*kqsx;
  1201. const int ql = get_int_from_uint8(bxi->ql, kqsx);
  1202. const int ql0 = (ql >> 0) & 0x0F0F0F0F;
  1203. const int ql1 = (ql >> 4) & 0x0F0F0F0F;
  1204. const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
  1205. const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
  1206. const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;
  1207. const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
  1208. const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);
  1209. x_qs[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
  1210. x_qs[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
  1211. }
  1212. const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
  1213. const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
  1214. float * x_dmf = (float *) x_dm;
  1215. #pragma unroll
  1216. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
  1217. int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
  1218. if (need_check) {
  1219. i = min(i, i_max);
  1220. }
  1221. const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;
  1222. x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
  1223. }
  1224. #pragma unroll
  1225. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
  1226. int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
  1227. if (need_check) {
  1228. i = min(i, i_max);
  1229. }
  1230. const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;
  1231. x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
  1232. }
  1233. }
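// Illustration (added sketch, not part of the upstream kernels): load_tiles_q6_K above rebuilds each
// 6-bit q6_K weight from a 4-bit low part in ->ql (the 0x0F0F0F0F masks) and 2 high bits from ->qh
// shifted into bits 4..5 (the 0x30303030 masks); the packed value in [0, 63] is then re-centered to
// [-32, 31] by the __vsubss4(..., 0x20202020) over the whole 4-byte lane. A minimal per-byte version,
// with a hypothetical name:
static __device__ __forceinline__ int q6_K_unpack_byte_example(const int ql_nibble, const int qh_2bits) {
    const int packed = (ql_nibble & 0x0F) | ((qh_2bits & 0x03) << 4); // unsigned 6-bit value in [0, 63]
    return packed - 32;                                               // signed weight in [-32, 31]
}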
  1234. template <int mmq_x, int mmq_y, int nwarps>
  1235. static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
  1236. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  1237. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  1238. const float * x_dmf = (const float *) x_dm;
  1239. const int * y_qs = (const int *) y + 4;
  1240. const float * y_df = (const float *) y;
  1241. #pragma unroll
  1242. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1243. const int j = j0 + threadIdx.y;
  1244. #pragma unroll
  1245. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1246. const int i = i0 + threadIdx.x;
  1247. const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
  1248. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
  1249. &x_qs[i*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], &y_qs[j*MMQ_TILE_Y_K + (QR6_K*k0) % WARP_SIZE], sc,
  1250. x_dmf[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + ((QR6_K*k0) % WARP_SIZE)/QI8_1]);
  1251. }
  1252. }
  1253. }
  1254. template <int mmq_x, int mmq_y, int nwarps>
  1255. static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
  1256. const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc,
  1257. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  1258. #ifdef INT8_MMA_AVAILABLE
  1259. typedef mma_int_A_I16K4 mma_A;
  1260. typedef mma_int_B_J8K4 mma_B;
  1261. typedef mma_int_C_I16J8 mma_C;
  1262. const float * x_df = (const float *) x_dm;
  1263. const int * y_qs = (const int *) y + 4;
  1264. const float * y_df = (const float *) y;
  1265. const int i0 = threadIdx.y*mma_A::I;
1267. static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
  1269. mma_A A[4];
  1270. int scA[mma_C::ne/2][4];
  1271. float dA[mma_C::ne/2];
  1272. #pragma unroll
  1273. for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
  1274. #pragma unroll
  1275. for (int l = 0; l < mma_A::ne; ++l) {
  1276. const int i = i0 + mma_A::get_i(l);
  1277. const int k = QR6_K*k0 + QR6_K*kvdr + mma_A::get_k(l);
  1278. A[kvdr/2 + 0].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + 0];
  1279. A[kvdr/2 + 1].x[l] = x_qs[i*(QR6_K*WARP_SIZE + 1) + k + mma_A::K];
  1280. }
  1281. #pragma unroll
  1282. for (int l = 0; l < mma_C::ne/2; ++l) {
  1283. const int i = i0 + mma_C::get_i(2*l);
  1284. const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
  1285. scA[l][kvdr/2 + 0] = sc[kvdr/2 + 0];
  1286. scA[l][kvdr/2 + 1] = sc[kvdr/2 + 1];
  1287. }
  1288. }
  1289. #pragma unroll
  1290. for (int l = 0; l < mma_C::ne/2; ++l) {
  1291. const int i = i0 + mma_C::get_i(2*l);
  1292. dA[l] = x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + k0/QI6_K];
  1293. }
  1294. #pragma unroll
  1295. for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
  1296. float tmp[mma_C::ne] = {0.0f};
  1297. #pragma unroll
  1298. for (int kvdr = 0; kvdr < VDR_Q6_K_Q8_1_MMQ; kvdr += 4) {
  1299. mma_C C[2];
  1300. mma_B B[2];
  1301. float dB[mma_C::ne/2];
  1302. #pragma unroll
  1303. for (int l = 0; l < mma_B::ne; ++l) {
  1304. const int j = j0 + mma_B::get_j(l);
  1305. const int k = (2*k0 + 2*kvdr + mma_B::get_k(l)) % WARP_SIZE;
  1306. B[0].x[l] = y_qs[j*MMQ_TILE_Y_K + k + 0];
  1307. B[1].x[l] = y_qs[j*MMQ_TILE_Y_K + k + mma_B::K];
  1308. }
  1309. #pragma unroll
  1310. for (int l = 0; l < mma_C::ne/2; ++l) {
  1311. const int j = j0 + mma_C::get_j(l);
  1312. dB[l] = y_df[j*MMQ_TILE_Y_K + ((2*k0 + 2*kvdr)/QI8_1) % (WARP_SIZE/QI8_1)];
  1313. }
  1314. C[0].mma_K4(A[kvdr/2 + 0], B[0]);
  1315. C[1].mma_K4(A[kvdr/2 + 1], B[1]);
  1316. #pragma unroll
  1317. for (int l = 0; l < mma_C::ne; ++l) {
  1318. tmp[l] += (C[0].x[l]*scA[l/2][kvdr/2 + 0] + C[1].x[l]*scA[l/2][kvdr/2 + 1])*dB[l%2];
  1319. }
  1320. }
  1321. #pragma unroll
  1322. for (int l = 0; l < mma_C::ne; ++l) {
  1323. sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2];
  1324. }
  1325. }
  1326. #else
  1327. GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0);
  1328. NO_DEVICE_CODE;
  1329. #endif // INT8_MMA_AVAILABLE
  1330. }
  1331. template<int mmq_x, int mmq_y, int nwarps, bool need_check>
  1332. static __device__ __forceinline__ void mmq_write_back_dp4a(
  1333. const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
  1334. #pragma unroll
  1335. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1336. const int j = j0 + threadIdx.y;
  1337. if (j > j_max) {
  1338. return;
  1339. }
  1340. #pragma unroll
  1341. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1342. const int i = i0 + threadIdx.x;
  1343. if (need_check && i > i_max) {
  1344. continue;
  1345. }
  1346. dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
  1347. }
  1348. }
  1349. }
  1350. template<int mmq_x, int mmq_y, int nwarps, bool need_check>
  1351. static __device__ __forceinline__ void mmq_write_back_mma(
  1352. const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
  1353. typedef mma_int_C_I16J8 mma_C;
  1354. const int i0 = threadIdx.y*mma_C::I;
  1355. #ifdef INT8_MMA_AVAILABLE
  1356. static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
  1357. #endif // INT8_MMA_AVAILABLE
  1358. #pragma unroll
  1359. for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
  1360. #pragma unroll
  1361. for (int l = 0; l < mma_C::ne; ++l) {
  1362. const int j = j0 + mma_C::get_j(l);
  1363. if (j > j_max) {
  1364. continue;
  1365. }
  1366. const int i = i0 + mma_C::get_i(l);
  1367. if (need_check && i > i_max) {
  1368. continue;
  1369. }
  1370. dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
  1371. }
  1372. }
  1373. }
  1374. // -------------------------------------------------------------------------------------------------------------------------------------
  1375. template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
  1376. struct mmq_type_traits;
  1377. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1378. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
  1379. static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
  1380. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
  1381. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1382. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1383. };
  1384. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1385. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
  1386. static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
  1387. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
  1388. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1389. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1390. };
  1391. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1392. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
  1393. static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
  1394. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
  1395. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1396. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1397. };
  1398. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1399. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
  1400. static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
  1401. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
  1402. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1403. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1404. };
  1405. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1406. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
  1407. static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
  1408. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
  1409. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1410. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1411. };
  1412. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1413. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
  1414. static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
  1415. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
  1416. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1417. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1418. };
  1419. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1420. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
  1421. static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
  1422. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
  1423. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1424. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1425. };
  1426. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1427. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
  1428. static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
  1429. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
  1430. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1431. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1432. };
  1433. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1434. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
  1435. static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
  1436. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
  1437. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1438. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1439. };
  1440. template <int mmq_x, int mmq_y, int nwarps, bool need_check>
  1441. struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
  1442. static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
  1443. static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
  1444. static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma<mmq_x, mmq_y, nwarps>;
  1445. static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
  1446. };
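// Presumably, mmq_need_sum reports whether the block_q8_1_mmq data for the y tiles must also carry
// per-block sums of the quantized y values: the types returning true fold a constant offset or a
// per-sub-block minimum through that sum in their dot products, while the remaining types are loaded
// already centered/signed and only need the q8_1 scales.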
  1447. static bool mmq_need_sum(const ggml_type type_x) {
  1448. switch (type_x) {
  1449. case GGML_TYPE_Q4_0:
  1450. case GGML_TYPE_Q4_1:
  1451. return true;
  1452. case GGML_TYPE_Q5_0:
  1453. return false;
  1454. case GGML_TYPE_Q5_1:
  1455. return true;
  1456. case GGML_TYPE_Q8_0:
  1457. case GGML_TYPE_Q2_K:
  1458. case GGML_TYPE_Q3_K:
  1459. return false;
  1460. case GGML_TYPE_Q4_K:
  1461. case GGML_TYPE_Q5_K:
  1462. return true;
  1463. case GGML_TYPE_Q6_K:
  1464. return false;
  1465. default:
  1466. GGML_ASSERT(false);
  1467. break;
  1468. }
  1469. return false;
  1470. }
  1471. template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
  1472. static __device__ void mul_mat_q_process_tile(
  1473. const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
  1474. const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
  1475. const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
  1476. constexpr int qk = ggml_cuda_type_traits<type>::qk;
  1477. constexpr int qr = ggml_cuda_type_traits<type>::qr;
  1478. constexpr int qi = ggml_cuda_type_traits<type>::qi;
  1479. constexpr int mmq_y = get_mmq_y_device();
  1480. constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
  1481. constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
  1482. #ifdef INT8_MMA_AVAILABLE
  1483. constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_mma;
  1484. constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
  1485. #else
  1486. constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot_dp4a;
  1487. constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
  1488. #endif // INT8_MMA_AVAILABLE
  1489. constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
  1490. extern __shared__ char data_mul_mat_q[];
  1491. int * tile_x_qs = (int *) data_mul_mat_q;
  1492. half2 * tile_x_dm = (half2 *) (tile_x_qs + txs.qs);
  1493. int * tile_x_sc = (int *) (tile_x_dm + txs.dm);
  1494. int * tile_y = (int *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
  1495. constexpr int blocks_per_warp = WARP_SIZE / qi;
  1496. float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
  1497. const int tile_x_max_i = ne01 - it*mmq_y - 1;
  1498. const int tile_y_max_j = ne11 - jt*mmq_x - 1;
  1499. const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
  1500. for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
  1501. load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
  1502. #pragma unroll
  1503. for (int kr = 0; kr < qr; ++kr) {
  1504. const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));
  1505. #pragma unroll
  1506. for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
  1507. int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
  1508. tile_y[l] = by0[l];
  1509. }
  1510. __syncthreads();
  1511. // #pragma unroll // unrolling this loop causes too much register pressure
  1512. for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
  1513. vec_dot(tile_x_qs, tile_x_dm, tile_x_sc, tile_y, sum, k0);
  1514. }
  1515. __syncthreads();
  1516. }
  1517. }
  1518. if (fixup) {
  1519. write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
  1520. } else {
  1521. write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
  1522. }
  1523. }
  1524. // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
  1525. template <ggml_type type, int mmq_x, int nwarps, bool need_check>
  1526. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  1527. #if defined(RDNA3) || defined(RDNA2)
  1528. __launch_bounds__(WARP_SIZE*nwarps, 2)
  1529. #endif // defined(RDNA3) || defined(RDNA2)
  1530. #else
  1531. #if __CUDA_ARCH__ >= CC_VOLTA
  1532. __launch_bounds__(WARP_SIZE*nwarps, 1)
  1533. #else
  1534. __launch_bounds__(WARP_SIZE*nwarps, 2)
  1535. #endif // __CUDA_ARCH__ >= CC_VOLTA
  1536. #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  1537. static __global__ void mul_mat_q(
  1538. const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
  1539. const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
  1540. // Skip unused template specializations for faster compilation:
  1541. if (mmq_x > get_mmq_x_max_device()) {
  1542. NO_DEVICE_CODE;
  1543. return;
  1544. }
  1545. constexpr int qk = ggml_cuda_type_traits<type>::qk;
  1546. constexpr int qi = ggml_cuda_type_traits<type>::qi;
  1547. constexpr int mmq_y = get_mmq_y_device();
  1548. // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
  1549. #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
  1550. {
  1551. constexpr bool fixup = false;
  1552. mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
  1553. (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
  1554. blockIdx.x, blockIdx.y, 0, ne00/qk);
  1555. return;
  1556. }
  1557. #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
  1558. const int64_t blocks_per_ne00 = ne00 / qk;
  1559. constexpr int blocks_per_warp = WARP_SIZE / qi;
  1560. const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
  1561. const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
  1562. // kbc == k block continuous, current index in continuous ijk space.
  1563. int64_t kbc = GGML_PAD((int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
  1564. const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
  1565. // kb0 == k index when doing the matrix multiplication for an output tile.
  1566. int kb0_start = kbc % blocks_per_ne00;
  1567. int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
  1568. while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
  1569. const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile.
  1570. const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
  1571. constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
  1572. mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
  1573. (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
  1574. it, jt, kb0_start, kb0_stop);
  1575. kbc += blocks_per_ne00;
  1576. kbc -= kbc % blocks_per_ne00;
  1577. kb0_start = 0;
  1578. kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
  1579. }
  1580. if (kbc >= kbc_stop) {
  1581. return;
  1582. }
  1583. const int jt = kbc / (blocks_per_ne00*nty);
  1584. const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
1585. constexpr bool fixup = true; // The last iteration writes its data to the fixup buffer to avoid data races with other blocks.
  1586. mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
  1587. (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
  1588. it, jt, kb0_start, kb0_stop);
  1589. }
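// Host-side sketch of the stream-k split performed by mul_mat_q above: the flat index space of size
// ntx*nty*blocks_per_ne00, i.e. (tile j, tile i, k block), is divided as evenly as possible over the
// CUDA blocks, with both endpoints rounded to a multiple of blocks_per_warp via GGML_PAD; a tile whose
// k range is split across two CUDA blocks is completed through mul_mat_q_stream_k_fixup below. The
// helper name is hypothetical and the function is illustrative only; it assumes GGML_PAD and fprintf
// are available through the existing includes, and it is not called by any kernel.
static void print_stream_k_split(const int64_t blocks_per_ne00, const int64_t ntx, const int64_t nty,
                                 const int64_t nblocks, const int64_t blocks_per_warp) {
    for (int64_t bidx = 0; bidx < nblocks; ++bidx) {
        const int64_t kbc      = GGML_PAD( bidx   *blocks_per_ne00*ntx*nty / nblocks, blocks_per_warp);
        const int64_t kbc_stop = GGML_PAD((bidx+1)*blocks_per_ne00*ntx*nty / nblocks, blocks_per_warp);
        for (int64_t kbc_it = kbc; kbc_it < kbc_stop; ) {
            const int64_t jt        = kbc_it / (blocks_per_ne00*nty);                      // tile index along ne11
            const int64_t it        = (kbc_it - jt*blocks_per_ne00*nty) / blocks_per_ne00; // tile index along ne01
            const int64_t kb0_start = kbc_it % blocks_per_ne00;                            // first k block in this tile
            const int64_t remaining = kbc_stop - kbc_it;
            const int64_t kb0_stop  = blocks_per_ne00 < kb0_start + remaining ? blocks_per_ne00 : kb0_start + remaining;
            fprintf(stderr, "CUDA block %3d: tile it=%3d jt=%3d, k blocks [%3d, %3d)\n",
                    (int) bidx, (int) it, (int) jt, (int) kb0_start, (int) kb0_stop);
            kbc_it += kb0_stop - kb0_start; // a partial tile ends exactly at kbc_stop
        }
    }
}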
  1590. template <ggml_type type, int mmq_x, int nwarps, bool need_check>
  1591. static __global__ void mul_mat_q_stream_k_fixup(
  1592. float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
  1593. constexpr int mmq_y = get_mmq_y_device();
  1594. constexpr int qk = ggml_cuda_type_traits<type>::qk;
  1595. constexpr int qi = ggml_cuda_type_traits<type>::qi;
  1596. constexpr int blocks_per_warp = WARP_SIZE / qi;
  1597. const int64_t blocks_per_ne00 = ne00 / qk;
  1598. float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
  1599. const int ntx = (ne11 + mmq_x - 1) / mmq_x;
  1600. const int nty = (ne01 + mmq_y - 1) / mmq_y;
  1601. bool any_fixup = false;
  1602. const int bidx_start = (blockIdx.y*nty + blockIdx.x) * block_num_mmq / (gridDim.y*gridDim.x);
  1603. const int bidx_stop = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
  1604. for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
  1605. const int64_t kbc = GGML_PAD((int64_t) bidx *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
  1606. const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
  1607. // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
  1608. if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
  1609. continue;
  1610. }
  1611. const int jt = kbc_stop / (blocks_per_ne00*nty);
  1612. const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
  1613. // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
  1614. if (it != blockIdx.x || jt != blockIdx.y) {
  1615. continue;
  1616. }
  1617. any_fixup = true;
  1618. #pragma unroll
  1619. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1620. const int j = j0 + threadIdx.y;
  1621. #pragma unroll
  1622. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1623. const int i = i0 + threadIdx.x;
  1624. sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
  1625. }
  1626. }
  1627. }
  1628. if (!any_fixup) {
  1629. return;
  1630. }
  1631. dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
  1632. const int i_max = ne01 - blockIdx.x*mmq_y - 1;
  1633. const int j_max = ne11 - blockIdx.y*mmq_x - 1;
  1634. #pragma unroll
  1635. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  1636. const int j = j0 + threadIdx.y;
  1637. if (j > j_max) {
  1638. return;
  1639. }
  1640. #pragma unroll
  1641. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  1642. const int i = i0 + threadIdx.x;
  1643. if (need_check && i > i_max) {
  1644. continue;
  1645. }
  1646. dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
  1647. }
  1648. }
  1649. }
  1650. struct mmq_args {
  1651. const char * x; const char * y; float * dst;
  1652. int64_t ne00; int64_t ne01; int64_t stride01;
  1653. int64_t ne10; int64_t ne11; int64_t stride11;
  1654. int64_t ne0;
  1655. };
  1656. static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
  1657. const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
  1658. const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
  1659. const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
  1660. return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
  1661. }
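// The returned size matches the carve-up of data_mul_mat_q in mul_mat_q_process_tile: txs.qs ints for
// tile_x_qs, txs.dm half2 values for tile_x_dm and txs.sc ints for tile_x_sc, followed by the y tile of
// mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1) ints. The GGML_PAD to a multiple of MMQ_NWARPS*WARP_SIZE ints
// presumably leaves headroom so that the strided copy of the y tile (nwarps*WARP_SIZE elements per
// step) can run its final iteration with all threads without stepping past the allocation.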
  1662. template <ggml_type type, int mmq_x>
  1663. static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
  1664. const int id = ggml_cuda_get_device();
  1665. const int cc = ggml_cuda_info().devices[id].cc;
  1666. const int nsm = ggml_cuda_info().devices[id].nsm;
  1667. const int mmq_y = get_mmq_y_host(cc);
  1668. const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
  1669. const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
  1670. #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  1671. static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
  1672. if (!shmem_limit_raised[id]) {
  1673. CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
  1674. CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
  1675. shmem_limit_raised[id] = true;
  1676. }
  1677. #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  1678. const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
  1679. const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
  1680. const dim3 block_nums_xy_tiling(nty, ntx, 1);
  1681. const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
  1682. if (!use_stream_k) {
  1683. if (args.ne01 % mmq_y == 0) {
  1684. constexpr bool need_check = false;
  1685. mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
  1686. (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
  1687. } else {
  1688. constexpr bool need_check = true;
  1689. mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
  1690. (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
  1691. }
  1692. return;
  1693. }
  1694. const dim3 block_nums_mmq(nsm, 1, 1);
  1695. ggml_cuda_pool & pool = ctx.pool();
  1696. ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
  1697. if (args.ne01 % mmq_y == 0) {
  1698. constexpr bool need_check = false;
  1699. mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
  1700. (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
  1701. mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
  1702. (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
  1703. } else {
  1704. constexpr bool need_check = true;
  1705. mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
  1706. (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
  1707. mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
  1708. (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
  1709. }
  1710. }
  1711. template <ggml_type type>
  1712. void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
  1713. const int id = ggml_cuda_get_device();
  1714. const int nsm = ggml_cuda_info().devices[id].nsm;
  1715. const int cc = ggml_cuda_info().devices[id].cc;
  1716. const int smpbo = ggml_cuda_info().devices[id].smpbo;
  1717. const int mmq_x_max = get_mmq_x_max_host(cc);
  1718. const int mmq_y = get_mmq_y_host(cc);
  1719. const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
  1720. const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
  1721. int mmq_x_best = 0;
  1722. int nparts_best = INT_MAX;
  1723. for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
  1724. const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
  1725. const int nwaves_xy_tiling = ntiles_x*block_num_y;
  1726. const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
  1727. if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
  1728. mmq_x_best = mmq_x;
  1729. nparts_best = nparts;
  1730. }
  1731. }
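// Worked example (illustrative, assuming mmq_x_max >= 80 and that the shared-memory check passes for
// every candidate): with stream-k enabled and ne11 == 80, nparts == ntiles_x shrinks as mmq_x grows:
// 10 at mmq_x=8, 5 at 16, 4 at 24, 3 at 32, 2 from 40 through 72, and 1 at mmq_x=80, at which point
// nparts_best == 1 stops the loop and mmq_x_best == 80.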
  1732. switch (mmq_x_best) {
  1733. case 8:
  1734. launch_mul_mat_q<type, 8>(ctx, args, stream);
  1735. break;
  1736. case 16:
  1737. launch_mul_mat_q<type, 16>(ctx, args, stream);
  1738. break;
  1739. case 24:
  1740. launch_mul_mat_q<type, 24>(ctx, args, stream);
  1741. break;
  1742. case 32:
  1743. launch_mul_mat_q<type, 32>(ctx, args, stream);
  1744. break;
  1745. case 40:
  1746. launch_mul_mat_q<type, 40>(ctx, args, stream);
  1747. break;
  1748. case 48:
  1749. launch_mul_mat_q<type, 48>(ctx, args, stream);
  1750. break;
  1751. case 56:
  1752. launch_mul_mat_q<type, 56>(ctx, args, stream);
  1753. break;
  1754. case 64:
  1755. launch_mul_mat_q<type, 64>(ctx, args, stream);
  1756. break;
  1757. case 72:
  1758. launch_mul_mat_q<type, 72>(ctx, args, stream);
  1759. break;
  1760. case 80:
  1761. launch_mul_mat_q<type, 80>(ctx, args, stream);
  1762. break;
  1763. case 88:
  1764. launch_mul_mat_q<type, 88>(ctx, args, stream);
  1765. break;
  1766. case 96:
  1767. launch_mul_mat_q<type, 96>(ctx, args, stream);
  1768. break;
  1769. case 104:
  1770. launch_mul_mat_q<type, 104>(ctx, args, stream);
  1771. break;
  1772. case 112:
  1773. launch_mul_mat_q<type, 112>(ctx, args, stream);
  1774. break;
  1775. case 120:
  1776. launch_mul_mat_q<type, 120>(ctx, args, stream);
  1777. break;
  1778. case 128:
  1779. launch_mul_mat_q<type, 128>(ctx, args, stream);
  1780. break;
  1781. default:
  1782. fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
  1783. GGML_ASSERT(false);
  1784. break;
  1785. }
  1786. }
1787. #define DECL_MMQ_CASE(type) \
1788. template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \

1789. extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
  1790. extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
  1791. extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
  1792. extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
  1793. extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
  1794. extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
  1795. extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
  1796. extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
  1797. extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
  1798. extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
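// Usage sketch (assumed file layout, not shown in this header): each quantization type is expected to
// be explicitly instantiated in its own translation unit, e.g. a hypothetical mmq-instance-q4_0.cu
// containing only
//
//     #include "mmq.cuh"
//     DECL_MMQ_CASE(GGML_TYPE_Q4_0);
//
// so that the heavy template instantiations of mul_mat_q_case compile in parallel, while the extern
// declarations above let the rest of the CUDA backend link against them without re-instantiating here.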
  1799. // -------------------------------------------------------------------------------------------------------------------------
  1800. void ggml_cuda_op_mul_mat_q(
  1801. ggml_backend_cuda_context & ctx,
  1802. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
  1803. const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
  1804. const int64_t src1_padded_row_size, cudaStream_t stream);
  1805. bool ggml_cuda_supports_mmq(enum ggml_type type);