mmq.cuh 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309
  1. #include "common.cuh"
  2. #include "vecdotq.cuh"
  3. #include <climits>
  4. #include <cstdint>
  // Signature shared by the load_tiles_* functions below. Each one copies one tile of quantized x data
// (starting at block offset kbx0 of x, rows stride blocks apart) into the caller-provided tile buffers
// x_ql / x_dm / x_qh / x_sc. When need_check is set, row indices are clamped to i_max.
typedef void (*load_tiles_mmq_t)(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);

// Signature shared by the vec_dot_*_q8_1_mul_mat functions below. Each one reads the x tile buffers and
// the q8_1 y tile (y_qs quants, y_ds scales) at tile column offset k0, and accumulates into sum.
// The y scale parameter is named y_ds to match every implementation in this file.
typedef void (*vec_dot_mmq_t)(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0);
  // Sizes of the four x tile buffers used by the mmq kernels, one field per buffer:
// presumably element counts, matching the row indexing used by the load_tiles_* functions.
// See the TILE_X_SIZES_* macros for the per-type values. A value of 0 means the type does not
// use that buffer.
struct tile_x_sizes {
    int ql, dm; // x_ql (quants), x_dm (block scale or scale/min pair)
    int qh, sc; // x_qh (extra high bits), x_sc (sub-block scales)
};
  // The host-side counterpart, get_mmq_x_max_host, lives in common.cuh so that it can also be used to pick
// the correct rounding for --split-mode row.
//
// Largest mmq_x (the j-extent of a tile in the vec_dot_* functions) for the architecture being compiled:
//   AMD HIP                              -> 64
//   CUDA >= Volta with tensor cores on   -> MMQ_MAX_BATCH_SIZE
//   CUDA >= Volta without tensor cores   -> 128
//   older CUDA                           -> 64
static constexpr __device__ int get_mmq_x_max_device() {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    return 64;
#elif __CUDA_ARCH__ >= CC_VOLTA && defined(CUDA_USE_TENSOR_CORES)
    return MMQ_MAX_BATCH_SIZE;
#elif __CUDA_ARCH__ >= CC_VOLTA
    return 128;
#else
    return 64;
#endif // HIP / __CUDA_ARCH__ >= CC_VOLTA / CUDA_USE_TENSOR_CORES
}
  // The host-side counterpart, get_mmq_y_host, lives in common.cuh so that it can also be used to pick
// the correct rounding for --split-mode row.
//
// mmq_y (the i-extent of a tile, i.e. rows of x) as a function of mmq_x. AMD HIP and CUDA Volta+ use
// 128 rows once mmq_x >= 32 and 64 rows otherwise; older CUDA architectures always use 64 rows.
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA
static constexpr __device__ int get_mmq_y_device(int mmq_x) {
    return mmq_x >= 32 ? 128 : 64;
}
#else
static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
    return 64;
}
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_VOLTA
  // Per-type x tile buffer sizes for a tile of mmq_y rows. These expand textually, so a name mmq_y must be
// in scope at the expansion site (a runtime int on the host, a template parameter on the device).
// Field order is {ql, dm, qh, sc}, as in struct tile_x_sizes.
// - ql: WARP_SIZE ints per row (2*WARP_SIZE for the *2 variants, e.g. q5_0/q5_1, whose loaders store two
//   expanded ints per packed int), plus one padding int per row (the "+ mmq_y" term). This matches the
//   row stride WARP_SIZE + 1 (or 2*WARP_SIZE + 1) used when indexing x_ql.
// - dm: WARP_SIZE/QI entries per row plus one padding entry every QI rows. This matches the
//   i*(WARP_SIZE/QI) + i/QI indexing used when writing x_dm.
// - qh / sc: same idea for the extra high-bit and scale buffers; 0 when the type does not use them.
#define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0, 0}
#define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0, 0}
#define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0, 0}
#define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0, 0}
#define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0, 0}
#define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0, mmq_y*WARP_SIZE/4 + mmq_y/4}
#define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4}
#define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
#define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
// Shared function body: a single return expression selecting the TILE_X_SIZES_* entry for `type`, with
// all-zero sizes for unsupported types. Kept as one ternary chain so it is usable in a C++11-style
// constexpr function. No comments inside the macro: a // comment before a line-continuation would
// swallow the following line.
#define GET_TILE_X_SIZES_BODY \
    return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
        type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \
        type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \
        type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \
        type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \
        type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \
        type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \
        type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
        type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
        type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
        tile_x_sizes{0, 0, 0, 0}
  // Host-side lookup of the x tile buffer sizes for quantization `type`. mmq_y is a runtime argument here;
// the body is shared with get_tile_x_sizes_device through GET_TILE_X_SIZES_BODY, which reads `type` and `mmq_y`.
static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
    GET_TILE_X_SIZES_BODY;
}
  // Device-side lookup of the x tile buffer sizes for quantization `type`. mmq_y is a template parameter
// and the function is constexpr, so the sizes can be compile-time constants when `type` is too.
// The body is shared with get_tile_x_sizes_host through GET_TILE_X_SIZES_BODY.
template <int mmq_y>
static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
    GET_TILE_X_SIZES_BODY;
}
  78. // ------------------------------------------------------------
  // Loads an mmq_y-row tile of q4_0 x data.
//   x_ql: the packed quants, one int per thread per row (row stride WARP_SIZE + 1).
//   x_dm: reinterpreted as floats; receives each block's scale d (q4_0 has no min).
// x_qh and x_sc are unused for q4_0. With need_check, row indices are clamped to i_max.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    const block_q4_0 * x_blocks = (const block_q4_0 *) x + kbx0;
    float * x_df = (float *) x_dm;

    // Pass 1: quants. Thread x handles int (threadIdx.x % QI4_0) of block (threadIdx.x / QI4_0).
    const int qs_block = threadIdx.x / QI4_0;
    const int qs_word  = threadIdx.x % QI4_0;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
        int row = row0 + threadIdx.y;
        if (need_check) {
            row = min(row, i_max);
        }
        const block_q4_0 * blk = x_blocks + row*stride + qs_block;
        x_ql[row * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(blk->qs, qs_word);
    }

    // Pass 2: one scale per block; several rows are covered per warp iteration.
    const int blocks_per_row = WARP_SIZE / QI4_0;
    const int d_block        = threadIdx.x % blocks_per_row;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI4_0) {
        int row = row0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_row;
        if (need_check) {
            row = min(row, i_max);
        }
        const block_q4_0 * blk = x_blocks + row*stride + d_block;
        x_df[row * (WARP_SIZE/QI4_0) + row / QI4_0 + d_block] = blk->d;
    }
}
  // Accumulates the q4_0 x q8_1 dot products for tile column offset k0 into sum.
// This thread owns columns threadIdx.y + n*nwarps and rows threadIdx.x + m*WARP_SIZE of the tile.
// The y words for a column do not depend on the row, so they are gathered once per column.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    const float * x_df = (const float *) x_dm;

    // First y word paired with x column k0; it depends on k0 only.
    const int y_word0 = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));

#pragma unroll
    for (int col0 = 0; col0 < mmq_x; col0 += nwarps) {
        const int col = col0 + threadIdx.y;

        // Interleave the y words matching the low and high nibbles of the x ints.
        int u[2*VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
        for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
            u[2*l + 0] = y_qs[col * WARP_SIZE + (y_word0 + l)         % WARP_SIZE];
            u[2*l + 1] = y_qs[col * WARP_SIZE + (y_word0 + l + QI4_0) % WARP_SIZE];
        }
        const half2 y_d = y_ds[col * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];

#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += WARP_SIZE) {
            const int row = row0 + threadIdx.x;
            sum[col0/nwarps*mmq_y/WARP_SIZE + row0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>(
                &x_ql[row * (WARP_SIZE + 1) + k0], u, x_df[row * (WARP_SIZE/QI4_0) + row/QI4_0 + k0/QI4_0], y_d);
        }
    }
}
  // Loads an mmq_y-row tile of q4_1 x data.
//   x_ql: the packed quants, one int per thread per row (row stride WARP_SIZE + 1).
//   x_dm: receives each block's half2 dm value.
// x_qh and x_sc are unused for q4_1. With need_check, row indices are clamped to i_max.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    const block_q4_1 * x_blocks = (const block_q4_1 *) x + kbx0;

    // Pass 1: quants, using the aligned int reader.
    const int qs_block = threadIdx.x / QI4_1;
    const int qs_word  = threadIdx.x % QI4_1;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
        int row = row0 + threadIdx.y;
        if (need_check) {
            row = min(row, i_max);
        }
        const block_q4_1 * blk = x_blocks + row*stride + qs_block;
        x_ql[row * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(blk->qs, qs_word);
    }

    // Pass 2: one dm pair per block.
    const int blocks_per_row = WARP_SIZE / QI4_1;
    const int dm_block       = threadIdx.x % blocks_per_row;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI4_1) {
        int row = row0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_row;
        if (need_check) {
            row = min(row, i_max);
        }
        const block_q4_1 * blk = x_blocks + row*stride + dm_block;
        x_dm[row * (WARP_SIZE/QI4_1) + row / QI4_1 + dm_block] = blk->dm;
    }
}
  // Accumulates the q4_1 x q8_1 dot products for tile column offset k0 into sum.
// The y words and y scales for a column do not depend on the row, so they are gathered once per column.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    const int y_word0 = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));

#pragma unroll
    for (int col0 = 0; col0 < mmq_x; col0 += nwarps) {
        const int col = col0 + threadIdx.y;

        int u[2*VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
        for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
            u[2*l + 0] = y_qs[col * WARP_SIZE + (y_word0 + l)         % WARP_SIZE];
            u[2*l + 1] = y_qs[col * WARP_SIZE + (y_word0 + l + QI4_1) % WARP_SIZE];
        }
        const half2 y_d = y_ds[col * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];

#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += WARP_SIZE) {
            const int row = row0 + threadIdx.x;
            sum[col0/nwarps*mmq_y/WARP_SIZE + row0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>(
                &x_ql[row * (WARP_SIZE + 1) + k0], u, x_dm[row * (WARP_SIZE/QI4_1) + row/QI4_1 + k0/QI4_1], y_d);
        }
    }
}
  // Loads an mmq_y-row tile of q5_0 x data.
// Each packed int of 4-bit quants is combined with its high bits from qh and expanded into two ints of four
// signed bytes each (5-bit values minus 16). x_ql therefore holds 2*WARP_SIZE ints per row
// (row stride 2*WARP_SIZE + 1). x_dm is reinterpreted as floats and receives each block's scale d.
// x_qh and x_sc are unused. With need_check, row indices are clamped to i_max.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    const int kbx = threadIdx.x / QI5_0;  // block within the tile row handled by this thread
    const int kqsx = threadIdx.x % QI5_0; // int within that block
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
        const int ql = get_int_from_uint8(bxi->qs, kqsx);
        // Shift so the 4 high bits for this int's low nibbles sit at bits 0-3, and those for its
        // high nibbles at bits 16-19.
        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
        // Low nibbles: move qh bits 0..3 into bit 4 of each byte.
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
        // High nibbles: move qh bits 16..19 into bit 4 of each byte.
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
        qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
    }
    // Per-block scales: one float per block, plus one padding entry every QI5_0 rows.
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
        int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
    }
}
  // Accumulates the q5_0 x q8_1 dot products for tile column offset k0 into sum.
// x_ql holds the q5_0 quants already expanded to signed bytes, so the q8_0 kernel is reused with
// QR5_0*VDR_Q5_0_Q8_1_MMQ ints. Both scale tiles are read as plain floats here.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    const float * x_df = (const float *) x_dm;
    const float * y_df = (const float *) y_ds;

    const int y_word0 = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));

#pragma unroll
    for (int col0 = 0; col0 < mmq_x; col0 += nwarps) {
        const int col = col0 + threadIdx.y;

        int u[2*VDR_Q5_0_Q8_1_MMQ];
#pragma unroll
        for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
            u[2*l + 0] = y_qs[col * WARP_SIZE + (y_word0 + l)         % WARP_SIZE];
            u[2*l + 1] = y_qs[col * WARP_SIZE + (y_word0 + l + QI5_0) % WARP_SIZE];
        }
        const float y_d = y_df[col * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];

#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += WARP_SIZE) {
            const int row  = row0 + threadIdx.x;
            const int x_db = row * (WARP_SIZE/QI5_0) + row/QI5_0 + k0/QI5_0;
            sum[col0/nwarps*mmq_y/WARP_SIZE + row0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>(
                &x_ql[row * (2*WARP_SIZE + 1) + 2 * k0], u, x_df[x_db], y_d);
        }
    }
}
  // Loads an mmq_y-row tile of q5_1 x data.
// Same repacking as q5_0: each packed int of 4-bit quants is combined with its high bits from qh into two ints
// of 5-bit bytes (row stride 2*WARP_SIZE + 1). Unlike q5_0, no 16 is subtracted here; the offset is
// presumably accounted for via the block's min in dm (confirm in vec_dot_q8_1_q8_1_impl).
// x_dm receives each block's half2 dm value. x_qh and x_sc are unused.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
    const int kbx = threadIdx.x / QI5_1;  // block within the tile row handled by this thread
    const int kqsx = threadIdx.x % QI5_1; // int within that block
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        // Shift so the high bits for this int sit at bits 0-3 (low nibbles) and 16-19 (high nibbles).
        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
        // Low nibbles + high bits 0..3 -> bit 4 of each byte.
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
        qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
        // High nibbles + high bits 16..19 -> bit 4 of each byte.
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
        qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
        qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
        qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
        x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
    }
    // Per-block dm pairs, plus one padding entry every QI5_1 rows.
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
        int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
    }
}
  // Accumulates the q5_1 x q8_1 dot products for tile column offset k0 into sum.
// x_ql holds the quants already expanded to 5-bit bytes (row stride 2*WARP_SIZE + 1), so the q8_1 x q8_1
// kernel is reused with QR5_1*VDR_Q5_1_Q8_1_MMQ ints. x_dm holds one half2 dm per block.
// Values that depend only on k0 or on the column are computed once, outside the row loop.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    // First y word paired with x column k0 (k0 only).
    const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

        // Gather this column's y words once; they are the same for every row i.
        int u[2*VDR_Q5_1_Q8_1_MMQ];
#pragma unroll
        for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
            u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l)         % WARP_SIZE];
            u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
        }
        const half2 y_dm = y_ds[j * (WARP_SIZE/QI8_1) + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            const int index_bx = i * (WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
                (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k0], u, x_dm[index_bx], y_dm);
        }
    }
}
  // Loads an mmq_y-row tile of q8_0 x data.
//   x_ql: the signed 8-bit quants, one int per thread per row (row stride WARP_SIZE + 1).
//   x_dm: reinterpreted as floats; receives each block's scale d.
// x_qh and x_sc are unused for q8_0. With need_check, row indices are clamped to i_max.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    const block_q8_0 * x_blocks = (const block_q8_0 *) x + kbx0;
    float * x_df = (float *) x_dm;

    // Pass 1: quants.
    const int qs_block = threadIdx.x / QI8_0;
    const int qs_word  = threadIdx.x % QI8_0;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps) {
        int row = row0 + threadIdx.y;
        if (need_check) {
            row = min(row, i_max);
        }
        const block_q8_0 * blk = x_blocks + row*stride + qs_block;
        x_ql[row * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(blk->qs, qs_word);
    }

    // Pass 2: one scale per block.
    const int blocks_per_row = WARP_SIZE / QI8_0;
    const int d_block        = threadIdx.x % blocks_per_row;

#pragma unroll
    for (int row0 = 0; row0 < mmq_y; row0 += nwarps * QI8_0) {
        int row = row0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_row;
        if (need_check) {
            row = min(row, i_max);
        }
        const block_q8_0 * blk = x_blocks + row*stride + d_block;
        x_df[row * (WARP_SIZE/QI8_0) + row / QI8_0 + d_block] = blk->d;
    }
}
  // Accumulates the q8_0 x q8_1 dot products for tile column offset k0 into sum.
// x and y ints line up one-to-one here, so y is read directly at offset k0 without gathering.
// Both scale tiles are read as plain floats.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);

    const float * x_df = (const float *) x_dm;
    const float * y_df = (const float *) y_ds;

#pragma unroll
    for (int col0 = 0; col0 < mmq_x; col0 += nwarps) {
        const int col = col0 + threadIdx.y;

        const int * y_col = &y_qs[col * WARP_SIZE + k0];
        const float y_d   = y_df[col * (WARP_SIZE/QI8_1) + k0/QI8_1];

#pragma unroll
        for (int row0 = 0; row0 < mmq_y; row0 += WARP_SIZE) {
            const int row = row0 + threadIdx.x;
            sum[col0/nwarps*mmq_y/WARP_SIZE + row0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>(
                &x_ql[row * (WARP_SIZE + 1) + k0], y_col, x_df[row * (WARP_SIZE/QI8_0) + row/QI8_0 + k0/QI8_0], y_d);
        }
    }
}
  // Loads an mmq_y-row tile of q2_K x data in three passes:
//   1. x_ql: the packed 2-bit quants, one int per thread per row (row stride WARP_SIZE + 1);
//   2. x_dm: one half2 dm per q2_K block;
//   3. x_sc: the block's scale bytes, WARP_SIZE/4 ints per row plus one padding int every 4 rows.
// x_qh is unused. With need_check, row indices are clamped to i_max.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh);
    const int kbx = threadIdx.x / QI2_K;  // block within the tile row handled by this thread
    const int kqsx = threadIdx.x % QI2_K; // int within that block
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
        // % mmq_y keeps the row in range when nwarps*QI2_K exceeds mmq_y. Threads that wrap onto the same
        // row also share kbxd, so they write the same value to the same slot.
        int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd;
        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        // each warp covers 4 rows per iteration, WARP_SIZE/4 threads per row
        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4);
        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4));
    }
}
  // Accumulates the q2_K x q8_1 dot products for tile column offset k0 into sum.
// The 2-bit quants are extracted from x_ql with a shift and 0x03 byte mask. They are then passed together
// with this row's scale bytes (from x_sc) and dm (from x_dm) to vec_dot_q2_K_q8_1_impl_mmq.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
    GGML_UNUSED(x_qh);

    const float * y_df = (const float *) y_ds;

    // Quantities that depend on k0 only.
    const int xb    = k0 / QI2_K;                              // q2_K block within the tile row
    const int kq    = (k0 % QI2_K) * QR2_K;                    // position within that block, scaled by QR2_K
    const int shift = 2 * ((kq % (2*QI2_K)) / (QI2_K/2));      // selects which 2-bit field of each byte

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j     = j0 + threadIdx.y;
        const int y_idx = j * WARP_SIZE + (QR2_K*k0) % WARP_SIZE;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const int ql_idx = i * (WARP_SIZE + 1) + xb*QI2_K + (QI2_K/2) * (kq/(2*QI2_K)) + kq % (QI2_K/2);
            int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
#pragma unroll
            for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
                v[l] = (x_ql[ql_idx + l] >> shift) & 0x03030303;
            }

            const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + xb*4]) + kq/4;

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
                v, &y_qs[y_idx], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + xb], y_df[y_idx/QI8_1]);
        }
    }
}
  // Loads an mmq_y-row tile of q3_K x data in four passes:
//   1. x_ql: the packed low 2-bit quants, one int per thread per row (row stride WARP_SIZE + 1);
//   2. x_dm: reinterpreted as floats; one scale d per q3_K block;
//   3. x_qh: the bitwise-inverted hmask ints, WARP_SIZE/2 per row plus one padding int every 2 rows;
//   4. x_sc: the 6-bit sub-block scales reassembled into signed bytes (value - 32),
//      WARP_SIZE/4 ints per row plus one padding int every 4 rows.
// With need_check, row indices are clamped to i_max.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    const int kbx = threadIdx.x / QI3_K;
    const int kqsx = threadIdx.x % QI3_K;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
        // % mmq_y keeps the row in range when nwarps*QI3_K exceeds mmq_y; wrapped threads rewrite identical values
        int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
        int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2);
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2);
        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
        x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2));
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);
        const int ksc = threadIdx.x % (QI3_K/4);
        // low 4 bits of each scale: low or high nibble (shift_low) of scales int ksc_low
        const int ksc_low = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
        // upper 2 bits of each scale: 2-bit field 2*ksc of scales int QI3_K/8, moved to bits 4-5 of each byte
        const int ksc_high = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
        // per-byte signed subtract of 32 turns the unsigned 6-bit value into a signed scale
        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
        x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
    }
}
  // Accumulates the q3_K x q8_1 dot products for tile column offset k0 into sum.
// Each 3-bit quant is rebuilt as (low 2 bits from x_ql) - (4 if the inverted hmask bit in x_qh is set),
// using a per-byte signed subtract. It is then passed with this row's signed scales (x_sc) and float d
// (x_dm) to vec_dot_q3_K_q8_1_impl_mmq.
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
    const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
    const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
    const float * x_df = (const float *) x_dm;
    const float * y_df = (const float *) y_ds;

    // Quantities that depend on k0 only.
    const int xb    = k0 / QI3_K;
    const int kq    = (k0 % QI3_K) * QR3_K;
    const int shift = 2 * ((kq % 32) / 8);

#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j     = j0 + threadIdx.y;
        const int y_idx = j * WARP_SIZE + (k0*QR3_K) % WARP_SIZE;

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + xb*4)) + kq/4;

            // Row-dependent bases; identical for every l below.
            const int ql_idx = i * (WARP_SIZE + 1) + xb*QI3_K + (QI3_K/2) * (kq/(2*QI3_K)) + kq % (QI3_K/2);
            const int qh_row = i * (WARP_SIZE/2) + i/2 + xb * (QI3_K/2);

            int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
#pragma unroll
            for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
                const int lo  = (x_ql[ql_idx + l] >> shift) & 0x03030303;
                const int hi  = x_qh[qh_row + (kq+l)%8] >> ((kq+l) / 8);
                const int sub = (hi << 2) & 0x04040404;
                v[l] = __vsubss4(lo, sub);
            }

            sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
                v, &y_qs[y_idx], scales, x_df[i * (WARP_SIZE/QI3_K) + i/QI3_K + xb], y_df[y_idx/QI8_1]);
        }
    }
}
  // Loads an mmq_y-row tile of q4_K x data in three passes:
//   1. x_ql: the packed 4-bit quants, one int per thread per row (row stride WARP_SIZE + 1);
//   2. x_dm: one half2 dm per q4_K block;
//   3. x_sc: the packed 6-bit scales and mins unpacked to one value per byte,
//      WARP_SIZE/8 ints per row plus one padding int every 8 rows.
// x_qh is unused. With need_check, row indices are clamped to i_max.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh);
    const int kbx = 0; // threadIdx.x / QI4_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI4_K
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx;
        x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        // % mmq_y keeps the row in range when nwarps*QI4_K exceeds mmq_y; wrapped threads rewrite identical values
        int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd;
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;
        if (need_check) {
            i = min(i, i_max);
        }
        const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8);
        const int * scales = (const int *) bxi->scales;
        const int ksc = threadIdx.x % (WARP_SIZE/8);
        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        // (WARP_SIZE/8 ints of 4 bytes each per row: 8 scales followed by 8 mins)
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits
        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
  553. template <int mmq_x, int mmq_y, int nwarps>
  554. static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
  555. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  556. const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
  557. GGML_UNUSED(x_qh);
  558. #pragma unroll
  559. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  560. const int j = j0 + threadIdx.y;
  561. #pragma unroll
  562. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  563. const int i = i0 + threadIdx.x;
  564. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2*((k0 % 16) / 8);
  565. const int index_y = j * WARP_SIZE + (QR4_K*k0) % WARP_SIZE;
  566. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
  567. &x_ql[i * (WARP_SIZE + 1) + k0], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
  568. }
  569. }
  570. }
// Loads one mmq_y-row tile of q5_K data for the current kbx0 offset into the x tile buffers.
// Same structure as load_tiles_q4_K, except that the quants are expanded to one byte per value:
// each lane merges 4-bit values from qs with the matching high bit from qh (placed in bit 4 of
// each byte) and writes two ints per row into x_ql, whose row pitch is 2*WARP_SIZE + 1 ints.
//   x_dm : the block's half2 `dm` value per row
//   x_sc : 6-bit scales and mins repacked into bytes, WARP_SIZE/8 ints per row (+1 int every 8 rows)
//   x_qh : unused for q5_K
// kbx0 is the first block of the tile, i_max the last row inside x (clamping only when need_check),
// stride the row stride of x in blocks.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh);

    const int kbx = 0; // threadIdx.x / QI5_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI5_K

    // Quants: each warp (threadIdx.y) handles one row per iteration, each lane one int of qs.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // re-read the last valid row instead of reading past the end of x
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx;
        const int ky = QR5_K*kqsx;

        // Low and high nibbles of 4 qs bytes.
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // The matching high bits from qh, moved into bit 4 of each byte.
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;

        // Destination ints in the row for the low-nibble and high-nibble values.
        const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4);

        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256

    // Block scale/min (dm): one row per thread; `% mmq_y` keeps surplus threads inside the tile.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd;

        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
    }

    // Sub-block scales/mins: WARP_SIZE/8 lanes per row, 8 rows per warp per iteration.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8);

        const int * scales = (const int *) bxi->scales;

        const int ksc = threadIdx.x % (WARP_SIZE/8);

        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        // (ksc == 0 -> sc0..sc3, 1 -> sc4..sc7, 2 -> m0..m3, 3 -> m4..m7, one value per byte)
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8    |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits

        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
  622. template <int mmq_x, int mmq_y, int nwarps>
  623. static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
  624. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  625. const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
  626. GGML_UNUSED(x_qh);
  627. #pragma unroll
  628. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  629. const int j = j0 + threadIdx.y;
  630. #pragma unroll
  631. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  632. const int i = i0 + threadIdx.x;
  633. const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]) + 2 * ((k0 % 16) / 8);
  634. const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k0;
  635. const int index_y = j * WARP_SIZE + (QR5_K*k0) % WARP_SIZE;
  636. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
  637. &x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
  638. }
  639. }
  640. }
// Loads one mmq_y-row tile of q6_K data for the current kbx0 offset into the x tile buffers:
//   x_ql : quants expanded to one signed byte per value (6-bit value minus 32), two ints written per
//          lane, row pitch 2*WARP_SIZE + 1 ints
//   x_dm : reinterpreted as float; receives the block's `d` (converted to float) per row
//   x_sc : the block's int8 scales copied as ints, WARP_SIZE/8 ints per row (+1 int every 8 rows)
//   x_qh : unused for q6_K
// kbx0 is the first block of the tile, i_max the last row inside x (clamping only when need_check),
// stride the row stride of x in blocks.
// NOTE(review): the unaligned get_int_from_uint8 / get_int_from_int8 readers are presumably used
// because block_q6_K fields are not 4-byte aligned — confirm against the block definition.
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
    const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
    int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
    GGML_UNUSED(x_qh);

    const int kbx = 0; // threadIdx.x / QI6_K
    const int kqsx = threadIdx.x; // threadIdx.x % QI6_K

    // Quants: each warp (threadIdx.y) handles one row per iteration, each lane one int of ql.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + threadIdx.y;

        if (need_check) {
            i = min(i, i_max); // re-read the last valid row instead of reading past the end of x
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx;
        const int ky = QR6_K*kqsx;

        // Low and high nibbles of 4 ql bytes.
        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;

        // Two matching high bits per value from qh, moved into bits 4-5 of each byte.
        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030;

        // Destination ints in the row for the low-nibble and high-nibble values.
        const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2);

        // Per-byte saturating subtraction of 32 turns the unsigned 6-bit values (0..63) into -32..31.
        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
    }

    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
    const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256
    float * x_dmf = (float *) x_dm; // q6_K only needs a single float scale per block

    // Block scale d: one row per thread; `% mmq_y` keeps surplus threads inside the tile.
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd;

        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }

    // Sub-block scales: WARP_SIZE/8 lanes per row, each copying one int (4 int8 scales).
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y;

        if (need_check) {
            i = min(i, i_max);
        }

        // The block offset (threadIdx.x % (WARP_SIZE/8)) / 4 is 0 whenever WARP_SIZE/8 <= 4.
        const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4;

        x_sc[i * (WARP_SIZE/8) + i / 8 + threadIdx.x % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, threadIdx.x % (QI6_K/8));
    }
}
  688. template <int mmq_x, int mmq_y, int nwarps>
  689. static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
  690. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  691. const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, float * __restrict__ sum, const int & k0) {
  692. GGML_UNUSED(x_qh);
  693. #pragma unroll
  694. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  695. const int j = j0 + threadIdx.y;
  696. #pragma unroll
  697. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  698. const int i = i0 + threadIdx.x;
  699. const float * x_dmf = (const float *) x_dm;
  700. const float * y_df = (const float *) y_ds;
  701. const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/8]);
  702. const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k0;
  703. const int index_y = j * WARP_SIZE + (QR6_K*k0) % WARP_SIZE;
  704. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
  705. &x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
  706. }
  707. }
  708. }
  709. // -------------------------------------------------------------------------------------------------------------------------------------
// Compile-time per-type configuration consumed by mul_mat_q:
//   need_sum   - true : mul_mat_q copies each q8_1 block's full half2 `ds` into tile_y_ds;
//                false: it stores only the low half, converted to float (the vec_dot then reads
//                       tile_y_ds as float, as vec_dot_q6_K_q8_1_mul_mat does).
//   vdr        - step of k0 between successive vec_dot calls within one pass over the tiles.
//   load_tiles - fills the x tile buffers for one kb0 step.
//   vec_dot    - accumulates x-tile * y-tile dot products into the per-thread sums.
// Only the specializations below exist; other types fail to compile when instantiated.
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
struct mmq_type_traits;

// Non-k-quant types.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
    static constexpr bool need_sum = true;
    static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
    static constexpr bool need_sum = true;
    static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
    static constexpr bool need_sum = false;
    static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
    static constexpr bool need_sum = true;
    static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
    static constexpr bool need_sum = false;
    static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

// k-quant types.
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
    static constexpr bool need_sum = false;
    static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
    static constexpr bool need_sum = false;
    static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
    static constexpr bool need_sum = true;
    static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
    static constexpr bool need_sum = true;
    static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};

template <int mmq_x, int mmq_y, int nwarps, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
    static constexpr bool need_sum = false;
    static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
    static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
    static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
};
// Tiled quantized matrix-multiplication kernel.
// Launch layout (see launch_mul_mat_q): blockDim = (WARP_SIZE, nwarps). Each block computes an
// mmq_y x mmq_x tile of dst: blockIdx.x selects mmq_y consecutive rows of x (rows of dst) and
// blockIdx.y selects mmq_x consecutive columns of y (columns of dst).
//   x         : src0 quantized as `type`; row r starts at block r*stride00
//   yc        : src1 as block_q8_1, ne10/QK8_1 blocks per column
//   dst       : float output; element (row i, column j) is written to dst[j*ne0 + i]
//   ne00      : elements per row of x
//   ne01      : number of rows of x
//   ne10/ne11 : elements per column / number of columns of y
//   ne0       : leading dimension of dst; with need_check, rows i >= ne0 are not stored
// Dynamic shared memory (size computed on the host) holds the x tile (ql, dm, qh, sc, sized by
// get_tile_x_sizes_device) followed by the y tile (mmq_x*WARP_SIZE ints of quants plus
// mmq_x*WARP_SIZE/QI8_1 half2 scales).
// The __launch_bounds__ minimum-blocks-per-SM hint is 2 on RDNA2/RDNA3, 1 on CUDA Volta and newer,
// and otherwise 1 for Q2_K and 2 for all other types.
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2)
    __launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2)
#else
#if __CUDA_ARCH__ >= CC_VOLTA
    __launch_bounds__(WARP_SIZE*nwarps, 1)
#else
    __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2)
#endif // __CUDA_ARCH__ >= CC_VOLTA
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
    const int ne00, const int ne01, const int stride00, const int ne10, const int ne11, const int ne0) {

    // Skip unused template specializations for faster compilation:
    if (mmq_x > get_mmq_x_max_device()) {
        NO_DEVICE_CODE;
        return;
    }

    constexpr int qk = ggml_cuda_type_traits<type>::qk;
    constexpr int qr = ggml_cuda_type_traits<type>::qr;
    constexpr int qi = ggml_cuda_type_traits<type>::qi;
    constexpr int mmq_y = get_mmq_y_device(mmq_x);
    constexpr bool need_sum = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::need_sum;
    constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
    constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
    constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;

    // Carve the dynamic shared memory into the x tile buffers followed by the y tile buffers.
    constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);

    extern __shared__ char data_mul_mat_q[];
    int * tile_x_ql = (int *) data_mul_mat_q;
    half2 * tile_x_dm = (half2 *) (tile_x_ql + txs.ql);
    int * tile_x_qh = (int *) (tile_x_dm + txs.dm);
    int * tile_x_sc = (int *) (tile_x_qh + txs.qh);
    int * tile_y_qs = (int *) (tile_x_sc + txs.sc); // [mmq_x * WARP_SIZE]
    half2 * tile_y_ds = (half2 *) (tile_y_qs + mmq_x*WARP_SIZE); // [mmq_x * WARP_SIZE/QI8_1];

    const block_q8_1 * y = (const block_q8_1 *) yc;

    const int blocks_per_row_x = ne00 / qk;
    const int blocks_per_col_y = ne10 / QK8_1;
    // x blocks consumed per kb0 step (qi is presumably the number of quant ints per x block).
    const int blocks_per_warp = WARP_SIZE / qi;

    const int & ne1 = ne11;

    // Last row of this block's x tile that is inside x; load_tiles clamps to it when need_check.
    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;

    // Per-thread accumulators: (mmq_x/nwarps) columns x (mmq_y/WARP_SIZE) rows.
    float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};

    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {

        load_tiles(x, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, stride00*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride00);

        // The x tile is consumed in qr passes; each pass loads the next WARP_SIZE ints of y per column.
#pragma unroll
        for (int kr = 0; kr < qr; ++kr) {
            const int kqs = kr*WARP_SIZE + threadIdx.x;
            const int kbxd = kqs / QI8_1;

            // y quants: one int per lane for each of this warp's columns.
#pragma unroll
            for (int i0 = 0; i0 < mmq_x; i0 += nwarps) {
                const int i = min(blockIdx.y*mmq_x + threadIdx.y + i0, ne11-1); // to prevent out-of-bounds memory accesses

                const block_q8_1 * by0 = &y[i*blocks_per_col_y + kb0 * (qk/QK8_1) + kbxd];

                const int index_y = (i0 + threadIdx.y) * WARP_SIZE + kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
            }

            // y block scales: WARP_SIZE/QI8_1 entries per column; `% mmq_x` keeps surplus threads in range.
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
                const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
                const int i_y_eff = min(blockIdx.y*mmq_x + ids, ne11-1);

                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const half2 * dsi_src = &y[i_y_eff*blocks_per_col_y + kb0 * (qk/QK8_1) + kr*(WARP_SIZE/QI8_1) + kby].ds;
                half2       * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = __low2float(*dsi_src);
                }
            }

            // Both tiles must be fully written before any thread reads them.
            __syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
            for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
                vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, sum, k0);
            }

            // All reads must finish before the next pass/iteration overwrites the tiles.
            __syncthreads();
        }
    }

    // Write-back: j grows with j0, so once j >= ne1 every later column is out of range too; returning
    // here is safe because no barrier follows.
#pragma unroll
    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;

        if (j >= ne1) {
            return;
        }

#pragma unroll
        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;

            if (need_check && i >= ne0) {
                continue;
            }

            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
        }
    }
}
// Host-side arguments for one quantized matrix multiplication (consumed by launch_mul_mat_q).
// NOTE(review): callers may build this with positional aggregate initialization — do not reorder fields.
struct mmq_args {
    const char * x; const char * y; float * dst; // x: quantized src0, y: block_q8_1 src1, dst: float output
    int64_t ne00; int64_t ne01; int64_t stride00; // x: elements per row, number of rows, row stride in blocks
    int64_t ne10; int64_t ne11; // y: elements per column (ne10/QK8_1 blocks), number of columns
    int64_t ne0; // dst leading dimension: column j starts at dst + j*ne0
};
  883. template <ggml_type type, int mmq_x, int nwarps>
  884. static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
  885. const int id = ggml_cuda_get_device();
  886. const int cc = ggml_cuda_info().devices[id].cc;
  887. const int mmq_y = get_mmq_y_host(cc, mmq_x);
  888. const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
  889. const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
  890. const dim3 block_nums(block_num_x, block_num_y, 1);
  891. const dim3 block_dims(WARP_SIZE, nwarps, 1);
  892. const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
  893. const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
  894. const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
  895. const int shmem = shmem_x + shmem_y;
  896. #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  897. static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
  898. if (!shmem_limit_raised[id]) {
  899. CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
  900. CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
  901. shmem_limit_raised[id] = true;
  902. }
  903. #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  904. if (args.ne01 % mmq_y == 0) {
  905. const bool need_check = false;
  906. mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
  907. (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
  908. } else {
  909. const bool need_check = true;
  910. mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
  911. (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride00, args.ne10, args.ne11, args.ne0);
  912. }
  913. }
  914. template <ggml_type type>
  915. void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
  916. const int id = ggml_cuda_get_device();
  917. const int nsm = ggml_cuda_info().devices[id].nsm;
  918. const int cc = ggml_cuda_info().devices[id].cc;
  919. const int mmq_x_max = get_mmq_x_max_host(cc);
  920. const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
  921. const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
  922. int mmq_x_best = 0;
  923. int nwaves_best = INT_MAX;
  924. for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
  925. const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
  926. const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
  927. if (nwaves < nwaves_best) {
  928. mmq_x_best = mmq_x;
  929. nwaves_best = nwaves;
  930. }
  931. }
  932. switch (mmq_x_best) {
  933. case 8:
  934. launch_mul_mat_q<type, 8, 4>(args, stream);
  935. break;
  936. case 16:
  937. launch_mul_mat_q<type, 16, 8>(args, stream);
  938. break;
  939. case 24:
  940. launch_mul_mat_q<type, 24, 8>(args, stream);
  941. break;
  942. case 32:
  943. launch_mul_mat_q<type, 32, 8>(args, stream);
  944. break;
  945. case 40:
  946. launch_mul_mat_q<type, 40, 8>(args, stream);
  947. break;
  948. case 48:
  949. launch_mul_mat_q<type, 48, 8>(args, stream);
  950. break;
  951. case 56:
  952. launch_mul_mat_q<type, 56, 8>(args, stream);
  953. break;
  954. case 64:
  955. launch_mul_mat_q<type, 64, 8>(args, stream);
  956. break;
  957. case 72:
  958. launch_mul_mat_q<type, 72, 8>(args, stream);
  959. break;
  960. case 80:
  961. launch_mul_mat_q<type, 80, 8>(args, stream);
  962. break;
  963. case 88:
  964. launch_mul_mat_q<type, 88, 8>(args, stream);
  965. break;
  966. case 96:
  967. launch_mul_mat_q<type, 96, 8>(args, stream);
  968. break;
  969. case 104:
  970. launch_mul_mat_q<type, 104, 8>(args, stream);
  971. break;
  972. case 112:
  973. launch_mul_mat_q<type, 112, 8>(args, stream);
  974. break;
  975. case 120:
  976. launch_mul_mat_q<type, 120, 8>(args, stream);
  977. break;
  978. case 128:
  979. launch_mul_mat_q<type, 128, 8>(args, stream);
  980. break;
  981. default:
  982. GGML_ASSERT(false);
  983. break;
  984. }
  985. }
  986. #define DECL_MMQ_CASE(type) \
  987. template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
  988. extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
  989. extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
  990. extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
  991. extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
  992. extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
  993. extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
  994. extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
  995. extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
  996. extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
  997. extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
  998. // -------------------------------------------------------------------------------------------------------------------------
// Host-side entry point of the quantized matmul (MMQ) path; its definition is not part of this header.
// NOTE(review): from the parameter names, it presumably computes rows [row_low, row_high) of dst from
// quantized src0 data (src0_dd_i) and q8_1-quantized src1 data (src1_ddq_i) with src1_ncols columns,
// writing floats to dst_dd_i on `stream` — confirm against the definition.
void ggml_cuda_op_mul_mat_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream);
// Reports whether the MMQ path supports src0 tensors of `type` (definition not in this header).
// NOTE(review): presumably true exactly for the types that have an mmq_type_traits specialization — confirm.
bool ggml_cuda_supports_mmq(enum ggml_type type);