mmq.cu

#include "mmq.cuh"
#include "vecdotq.cuh"

typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
typedef void (*load_tiles_cuda_t)(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row);
typedef float (*vec_dot_q_mul_mat_cuda_t)(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
typedef void (*dot_kernel_k_t)(const void * __restrict__ vx, const int ib, const int iqs, const float * __restrict__ y, float & v);
typedef void (mul_mat_q_t)(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst);
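// NOTE: the function pointer types above parameterize the generic mul_mat_q kernel defined
// further down: allocate_tiles_* reserves the shared memory tiles for one block of x data,
// load_tiles_* fills them from the quantized blocks, and vec_dot_*_mul_mat computes one
// partial dot product against the q8_1 y tile.
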
struct mmq_arch_config_t {
    int x;
    int y;
    int nwarps;
};

struct mmq_config_t {
    mmq_arch_config_t rdna2;
    mmq_arch_config_t rdna1;
    mmq_arch_config_t ampere;
    mmq_arch_config_t pascal;
};

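// Each mmq_arch_config_t entry below is {mmq_x, mmq_y, nwarps}: mmq_x is the tile size along
// the columns of y/dst, mmq_y the tile size along the rows of x/dst, and nwarps the number of
// warps per thread block. One entry is provided per target architecture; get_arch_config_device()
// selects rdna2 for RDNA2/RDNA3, rdna1 for older AMD GPUs, ampere for CC_VOLTA and newer
// NVIDIA GPUs, and pascal otherwise.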
constexpr mmq_config_t MMQ_CONFIG_Q4_0 = {
    //  x    y  nwarps
    { 64, 128, 8},
    { 64,  64, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    { 64, 128, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};
constexpr mmq_config_t MMQ_CONFIG_Q4_1 = {
    //  x    y  nwarps
    { 64, 128, 8},
    { 64,  64, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    { 64, 128, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};
constexpr mmq_config_t MMQ_CONFIG_Q5_0 = {
    //  x    y  nwarps
    { 64, 128, 8},
    { 64,  64, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    {128,  64, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};
constexpr mmq_config_t MMQ_CONFIG_Q5_1 = {
    //  x    y  nwarps
    { 64, 128, 8},
    { 64,  64, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    {128,  64, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};
constexpr mmq_config_t MMQ_CONFIG_Q8_0 = {
    //  x    y  nwarps
    { 64, 128, 8},
    { 64,  64, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    {128,  64, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};
constexpr mmq_config_t MMQ_CONFIG_Q2_K = {
    //  x    y  nwarps
    { 64, 128, 8},
    {128,  32, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    { 64, 128, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};
constexpr mmq_config_t MMQ_CONFIG_Q3_K = {
    //  x    y  nwarps
    {128,  64, 8},
    { 32, 128, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    {128, 128, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};
constexpr mmq_config_t MMQ_CONFIG_Q4_K = {
    //  x    y  nwarps
    { 64, 128, 8},
    { 32,  64, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    { 64, 128, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};
constexpr mmq_config_t MMQ_CONFIG_Q5_K = {
    //  x    y  nwarps
    { 64, 128, 8},
    { 32,  64, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    { 64, 128, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};
constexpr mmq_config_t MMQ_CONFIG_Q6_K = {
    //  x    y  nwarps
    { 64, 128, 8},
    { 32,  64, 8},
#ifdef CUDA_USE_TENSOR_CORES
    {  4,  32, 4},
#else
    { 64,  64, 4},
#endif // CUDA_USE_TENSOR_CORES
    { 64,  64, 8},
};

// ------------------------------------------------------------

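// The shared memory tiles below are over-allocated by one element per row (the "+ mmq_y" and
// "+ mmq_y/QI*" terms) and indexed with a stride of WARP_SIZE + 1 rather than WARP_SIZE; this is
// the usual padding trick to stagger rows across shared memory banks and avoid bank conflicts.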
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh);
    GGML_UNUSED(x_sc);
    __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
    *x_ql = tile_x_qs;
    *x_dm = (half2 *) tile_x_d;
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI4_0;
    const int kqsx = k % QI4_0;
    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
    }
}

static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const float * x_dmf = (const float *) x_dm;
    int u[2*VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
    }
    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
    *x_ql = tile_x_qs;
    *x_dm = tile_x_dm;
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI4_1;
    const int kqsx = k % QI4_1;
    const block_q4_1 * bx0 = (const block_q4_1 *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
    }
}

static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    int u[2*VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
    }
    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
    *x_ql = tile_x_ql;
    *x_dm = (half2 *) tile_x_d;
}

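// q5_0 (and q5_1 below) values are widened to 8 bits per element while loading: the low-bit
// nibbles from qs are merged with their high bits from qh, so the row stride of x_ql doubles to
// 2*WARP_SIZE + 1 and the dot product can reuse the q8_0/q8_1 implementations.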
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI5_0;
    const int kqsx = k % QI5_0;
    const block_q5_0 * bx0 = (const block_q5_0 *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
        const int ql = get_int_from_uint8(bxi->qs, kqsx);
        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4)  & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5)  & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2)  & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9)  & 0x10000000; // 19 -> 28
        qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
    const int kbxd = k % blocks_per_tile_x_row;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
    }
}

static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;
    int u[2*VDR_Q5_0_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
    }
    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
    *x_ql = tile_x_ql;
    *x_dm = tile_x_dm;
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI5_1;
    const int kqsx = k % QI5_1;
    const block_q5_1 * bx0 = (const block_q5_1 *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4)  & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5)  & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2)  & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9)  & 0x10000000; // 19 -> 28
        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
    }
}

static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const int index_bx = i * (WARP_SIZE/QI5_1) + i/QI5_1 + k/QI5_1;
    int u[2*VDR_Q5_1_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
    }
    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
    *x_ql = tile_x_qs;
    *x_dm = (half2 *) tile_x_d;
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI8_0;
    const int kqsx = k % QI8_0;
    float * x_dmf = (float *) x_dm;
    const block_q8_0 * bx0 = (const block_q8_0 *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
    }
}

static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;
    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
}

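// The K-quant paths below additionally stage per-sub-block scales in x_sc (and, for q3_K, the
// bit-inverted high-bit masks in x_qh) so that the vec_dot implementations can apply the
// per-sub-block scales without touching global memory again.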
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh);
    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4];
    *x_ql = tile_x_ql;
    *x_dm = tile_x_dm;
    *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh);
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI2_K;
    const int kqsx = k % QI2_K;
    const block_q2_K * bx0 = (const block_q2_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
    }
}

static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh);
    const int kbx = k / QI2_K;
    const int ky  = (k % QI2_K) * QR2_K;
    const float * y_df = (const float *) y_ds;
    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
#pragma unroll
    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
    }
    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K];
    __shared__ int   tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2];
    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4];
    *x_ql = tile_x_ql;
    *x_dm = tile_x_dm;
    *x_qh = tile_x_qh;
    *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI3_K;
    const int kqsx = k % QI3_K;
    const block_q3_K * bx0 = (const block_q3_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
    const int kbxd = k % blocks_per_tile_x_row;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
        const int ksc = k % (QI3_K/4);
        const int ksc_low = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
        const int ksc_high = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
    }
}

static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    const int kbx = k / QI3_K;
    const int ky  = (k % QI3_K) * QR3_K;
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;
    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
        const int shift = 2 * ((ky % 32) / 8);
        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
        const int vlh = (vh << 2) & 0x04040404;
        v[l] = __vsubss4(vll, vlh);
    }
    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh);
    __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
    *x_ql = tile_x_ql;
    *x_dm = tile_x_dm;
    *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh);
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
    const int kqsx = k % QI4_K; // == k if QK_K == 256
    const block_q4_K * bx0 = (const block_q4_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
#else
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
#endif
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
        const int * scales = (const int *) bxi->scales;
        const int ksc = k % (WARP_SIZE/8);
        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}

static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh);
    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh);
    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
    *x_ql = tile_x_ql;
    *x_dm = tile_x_dm;
    *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh);
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
    const int kqsx = k % QI5_K; // == k if QK_K == 256
    const block_q5_K * bx0 = (const block_q5_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR5_K*kqsx;
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
#endif
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
        const int * scales = (const int *) bxi->scales;
        const int ksc = k % (WARP_SIZE/8);
        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}

static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh);
    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
    const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
    const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
}

template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
    GGML_UNUSED(x_qh);
    __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
    __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
    __shared__ int   tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
    *x_ql = tile_x_ql;
    *x_dm = tile_x_dm;
    *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
    const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
    GGML_UNUSED(x_qh);
    GGML_CUDA_ASSUME(i_offset >= 0);
    GGML_CUDA_ASSUME(i_offset < nwarps);
    GGML_CUDA_ASSUME(k >= 0);
    GGML_CUDA_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
    const int kqsx = k % QI6_K; // == k if QK_K == 256
    const block_q6_K * bx0 = (const block_q6_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR6_K*kqsx;
        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
    }
}

static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
    GGML_UNUSED(x_qh);
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;
    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
    const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
    const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
}

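// mul_mat_q: generic tiled multiplication of a quantized x matrix with a q8_1-quantized y matrix.
// Each thread block computes an mmq_y x mmq_x tile of dst: the loop over ib0 walks the shared
// dimension one warp-sized group of quantized blocks at a time, staging an x tile and a y tile in
// shared memory, accumulating partial sums in registers, and writing them out to dst at the end.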
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;
    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp  = WARP_SIZE / qi;
    const int & ncols_dst = ncols_y;
    const int row_dst_0 = blockIdx.x*mmq_y;
    const int & row_x_0 = row_dst_0;
    const int col_dst_0 = blockIdx.y*mmq_x;
    const int & col_y_0 = col_dst_0;
    int   * tile_x_ql = nullptr;
    half2 * tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;
    allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
    __shared__ int   tile_y_qs[mmq_x * WARP_SIZE];
    __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
        load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                   threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
#pragma unroll
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir*WARP_SIZE + threadIdx.x;
            const int kbxd = kqs / QI8_1;
#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
                const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
            }
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
                const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = __low2float(*dsi_src);
                }
            }
            __syncthreads();
            // #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
                        sum[i/WARP_SIZE][j/nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
                            threadIdx.x + i, threadIdx.y + j, k);
                    }
                }
            }
            __syncthreads();
        }
    }
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const int col_dst = col_dst_0 + j + threadIdx.y;
        if (col_dst >= ncols_dst) {
            return;
        }
#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
            const int row_dst = row_dst_0 + threadIdx.x + i;
            if (row_dst >= nrows_dst) {
                continue;
            }
            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
        }
    }
}

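// Selects the per-architecture tile configuration at compile time; see the comment above the
// MMQ_CONFIG_* tables for how the rdna2/rdna1/ampere/pascal entries map to hardware.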
static constexpr __device__ mmq_arch_config_t get_arch_config_device(mmq_config_t mmq_config) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    return mmq_config.rdna2;
#else
    return mmq_config.rdna1;
#endif // defined(RDNA3) || defined(RDNA2)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
    return mmq_config.ampere;
#else
    return mmq_config.pascal;
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
}

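// The __global__ kernels below are thin per-type wrappers: each one picks its arch config at
// compile time and instantiates mul_mat_q with the matching allocate/load/vec_dot functions.
// On devices without DP4A support (__CUDA_ARCH__ < MIN_CC_DP4A) the body compiles to
// NO_DEVICE_CODE.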
template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_0.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    mul_mat_q4_0(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_0);
    mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_0<arch_config.y>,
        load_tiles_q4_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q4_0_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_1.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ < CC_VOLTA
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_1.pascal.nwarps, 2)
#endif // __CUDA_ARCH__ < CC_VOLTA
    mul_mat_q4_1(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_1);
    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_1<arch_config.y>,
        load_tiles_q4_1<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q4_1_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_0.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    mul_mat_q5_0(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_0);
    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q5_0<arch_config.y>,
        load_tiles_q5_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q5_0_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_1.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    mul_mat_q5_1(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_1);
    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q5_1<arch_config.y>,
        load_tiles_q5_1<arch_config.y, arch_config.nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q5_1_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q8_0.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    mul_mat_q8_0(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q8_0);
    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q8_0<arch_config.y>,
        load_tiles_q8_0<arch_config.y, arch_config.nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q8_0_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q2_K.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    mul_mat_q2_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q2_K);
    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q2_K<arch_config.y>,
        load_tiles_q2_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q2_K_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q3_K.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ < CC_VOLTA
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q3_K.pascal.nwarps, 2)
#endif // __CUDA_ARCH__ < CC_VOLTA
    mul_mat_q3_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q3_K);
    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q3_K<arch_config.y>,
        load_tiles_q3_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q3_K_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}


template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ < CC_VOLTA
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q4_K.pascal.nwarps, 2)
#endif // __CUDA_ARCH__ < CC_VOLTA
    mul_mat_q4_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q4_K);

    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q4_K<arch_config.y>,
        load_tiles_q4_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q4_K_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q5_K.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    mul_mat_q5_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q5_K);

    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q5_K<arch_config.y>,
        load_tiles_q5_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q5_K_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}

template <bool need_check> static __global__ void
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q6_K.rdna2.nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#elif __CUDA_ARCH__ < CC_VOLTA
    __launch_bounds__(WARP_SIZE*MMQ_CONFIG_Q6_K.pascal.nwarps, 2)
#endif // __CUDA_ARCH__ < CC_VOLTA
    mul_mat_q6_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst) {

#if __CUDA_ARCH__ >= MIN_CC_DP4A
    constexpr mmq_arch_config_t arch_config = get_arch_config_device(MMQ_CONFIG_Q6_K);

    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, arch_config.x, arch_config.y, arch_config.nwarps, allocate_tiles_q6_K<arch_config.y>,
        load_tiles_q6_K<arch_config.y, arch_config.nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
#else
    GGML_UNUSED(get_arch_config_device);
    GGML_UNUSED(vec_dot_q6_K_q8_1_mul_mat);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
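
// used by ggml_cuda_op_mul_mat_q below: if the number of rows handled by this device is a
// multiple of the tile height (arch_config.y) the bounds check in the tile loaders can be
// skipped (need_check == false), otherwise the checked kernel variant is launched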
#define MMQ_SWITCH_CASE(type_suffix) \
    case GGML_TYPE_Q##type_suffix: if (row_diff % arch_config.y == 0) { \
        const bool need_check = false; \
        mul_mat_q##type_suffix<need_check><<<block_nums, block_dims, 0, stream>>> \
            (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst); \
    } else { \
        const bool need_check = true; \
        mul_mat_q##type_suffix<need_check><<<block_nums, block_dims, 0, stream>>> \
            (src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst); \
    } break;
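
// host-side entry point: multiplies rows [row_low, row_high) of the quantized matrix src0
// (src0_dd_i) with the q8_1-quantized copy of src1 (src1_ddq_i) and writes float results
// to dst_dd_i on the given stream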
void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    const int64_t ne00 = src0->ne[0];

    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne0 = dst->ne[0];

    const int64_t row_diff = row_high - row_low;

    int id = ggml_cuda_get_device();
    const int compute_capability = ggml_cuda_info().devices[id].cc;

    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
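
    // select the tile configuration matching the quantization type of src0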
    mmq_config_t mmq_config;

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            mmq_config = MMQ_CONFIG_Q4_0;
            break;
        case GGML_TYPE_Q4_1:
            mmq_config = MMQ_CONFIG_Q4_1;
            break;
        case GGML_TYPE_Q5_0:
            mmq_config = MMQ_CONFIG_Q5_0;
            break;
        case GGML_TYPE_Q5_1:
            mmq_config = MMQ_CONFIG_Q5_1;
            break;
        case GGML_TYPE_Q8_0:
            mmq_config = MMQ_CONFIG_Q8_0;
            break;
        case GGML_TYPE_Q2_K:
            mmq_config = MMQ_CONFIG_Q2_K;
            break;
        case GGML_TYPE_Q3_K:
            mmq_config = MMQ_CONFIG_Q3_K;
            break;
        case GGML_TYPE_Q4_K:
            mmq_config = MMQ_CONFIG_Q4_K;
            break;
        case GGML_TYPE_Q5_K:
            mmq_config = MMQ_CONFIG_Q5_K;
            break;
        case GGML_TYPE_Q6_K:
            mmq_config = MMQ_CONFIG_Q6_K;
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
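
    // pick the per-architecture entry: RDNA2+ or RDNA1 on AMD, the Ampere entry for Volta
    // and newer NVIDIA GPUs, the Pascal entry for older GPUs with DP4A support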
    mmq_arch_config_t arch_config;
    if (compute_capability >= CC_RDNA2) {
        arch_config = mmq_config.rdna2;
    } else if (compute_capability >= CC_OFFSET_AMD) {
        arch_config = mmq_config.rdna1;
    } else if (compute_capability >= CC_VOLTA) {
        arch_config = mmq_config.ampere;
    } else if (compute_capability >= MIN_CC_DP4A) {
        arch_config = mmq_config.pascal;
    } else {
        GGML_ASSERT(false);
    }
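
    // one thread block computes a tile of arch_config.y rows of src0 by arch_config.x
    // columns of src1, so both grid dimensions are rounded up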
    const int block_num_x = (row_diff + arch_config.y - 1) / arch_config.y;
    const int block_num_y = (src1_ncols + arch_config.x - 1) / arch_config.x;
    const dim3 block_nums(block_num_x, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, arch_config.nwarps, 1);

    switch (src0->type) {
        MMQ_SWITCH_CASE(4_0)
        MMQ_SWITCH_CASE(4_1)
        MMQ_SWITCH_CASE(5_0)
        MMQ_SWITCH_CASE(5_1)
        MMQ_SWITCH_CASE(8_0)
        MMQ_SWITCH_CASE(2_K)
        MMQ_SWITCH_CASE(3_K)
        MMQ_SWITCH_CASE(4_K)
        MMQ_SWITCH_CASE(5_K)
        MMQ_SWITCH_CASE(6_K)
        default:
            GGML_ASSERT(false);
            break;
    }

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddf_i);
}
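
// returns true for the quantization types that have a mul_mat_q kernel above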
bool ggml_cuda_supports_mmq(enum ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
            return true;
        default:
            return false;
    }
}