mmq.cuh 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331
  1. #pragma once
  2. #include "common.cuh"
  3. #include "vecdotq.cuh"
  4. #include <climits>
  5. #include <cstdint>
  6. #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
  7. typedef void (*load_tiles_mmq_t)(
  8. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  9. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride);
  10. typedef void (*vec_dot_mmq_t)(
  11. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  12. const int * __restrict__ y, float * __restrict__ sum, const int & k0);
  13. struct block_q8_1_mmq {
  14. half2 ds[4];
  15. int8_t qs[4*QK8_1];
  16. };
  17. static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
  18. static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
  19. struct tile_x_sizes {
  20. int ql;
  21. int dm;
  22. int qh;
  23. int sc;
  24. };
  25. // get_mmq_x_max_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
  26. static constexpr __device__ int get_mmq_x_max_device() {
  27. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  28. return 64;
  29. #else
  30. #if __CUDA_ARCH__ >= CC_VOLTA
  31. #ifdef CUDA_USE_TENSOR_CORES
  32. return MMQ_MAX_BATCH_SIZE;
  33. #else
  34. return 128;
  35. #endif // CUDA_USE_TENSOR_CORES
  36. #else
  37. return 64;
  38. #endif // __CUDA_ARCH__ >= CC_VOLTA
  39. #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  40. }
  41. // get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
  42. #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  43. static constexpr __device__ int get_mmq_y_device(int mmq_x) {
  44. return mmq_x >= 32 ? 128 : 64;
  45. }
  46. #else
  47. #if __CUDA_ARCH__ >= CC_VOLTA
  48. static constexpr __device__ int get_mmq_y_device(int mmq_x) {
  49. return mmq_x >= 32 ? 128 : 64;
  50. }
  51. #else
  52. static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
  53. return 64;
  54. }
  55. #endif // __CUDA_ARCH__ >= CC_VOLTA
  56. #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  57. #define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0, 0}
  58. #define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0, 0}
  59. #define TILE_X_SIZES_Q5_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_0 + mmq_y/QI5_0, 0, 0}
  60. #define TILE_X_SIZES_Q5_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_1 + mmq_y/QI5_1, 0, 0}
  61. #define TILE_X_SIZES_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI8_0 + mmq_y/QI8_0, 0, 0}
  62. #define TILE_X_SIZES_Q2_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI2_K + mmq_y/QI2_K, 0, mmq_y*WARP_SIZE/4 + mmq_y/4}
  63. #define TILE_X_SIZES_Q3_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI3_K + mmq_y/QI3_K, mmq_y*WARP_SIZE/2 + mmq_y/2, mmq_y*WARP_SIZE/4 + mmq_y/4}
  64. #define TILE_X_SIZES_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
  65. #define TILE_X_SIZES_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
  66. #define TILE_X_SIZES_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, 0, mmq_y*WARP_SIZE/8 + mmq_y/8}
  67. #define GET_TILE_X_SIZES_BODY \
  68. return type == GGML_TYPE_Q4_0 ? TILE_X_SIZES_Q4_0 : \
  69. type == GGML_TYPE_Q4_1 ? TILE_X_SIZES_Q4_1 : \
  70. type == GGML_TYPE_Q5_0 ? TILE_X_SIZES_Q5_0 : \
  71. type == GGML_TYPE_Q5_1 ? TILE_X_SIZES_Q5_1 : \
  72. type == GGML_TYPE_Q8_0 ? TILE_X_SIZES_Q8_0 : \
  73. type == GGML_TYPE_Q2_K ? TILE_X_SIZES_Q2_K : \
  74. type == GGML_TYPE_Q3_K ? TILE_X_SIZES_Q3_K : \
  75. type == GGML_TYPE_Q4_K ? TILE_X_SIZES_Q4_K : \
  76. type == GGML_TYPE_Q5_K ? TILE_X_SIZES_Q5_K : \
  77. type == GGML_TYPE_Q6_K ? TILE_X_SIZES_Q6_K : \
  78. tile_x_sizes{0, 0, 0, 0}
  79. static tile_x_sizes get_tile_x_sizes_host(const ggml_type type, const int mmq_y) {
  80. GET_TILE_X_SIZES_BODY;
  81. }
  82. template <int mmq_y>
  83. static constexpr __device__ tile_x_sizes get_tile_x_sizes_device(ggml_type type) {
  84. GET_TILE_X_SIZES_BODY;
  85. }
  86. // ------------------------------------------------------------
  87. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
  88. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  89. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  90. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  91. const int kbx = threadIdx.x / QI4_0;
  92. const int kqsx = threadIdx.x % QI4_0;
  93. float * x_dmf = (float *) x_dm;
  94. #pragma unroll
  95. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  96. int i = i0 + threadIdx.y;
  97. if (need_check) {
  98. i = min(i, i_max);
  99. }
  100. const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx;
  101. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
  102. }
  103. const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
  104. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  105. #pragma unroll
  106. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
  107. int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row;
  108. if (need_check) {
  109. i = min(i, i_max);
  110. }
  111. const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd;
  112. x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
  113. }
  114. }
  115. template <int mmq_x, int mmq_y, int nwarps>
  116. static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
  117. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  118. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  119. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  120. const float * x_dmf = (const float *) x_dm;
  121. const int * y_qs = (const int *) y + 4;
  122. const half2 * y_ds = (const half2 *) y;
  123. #pragma unroll
  124. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  125. const int j = j0 + threadIdx.y;
  126. #pragma unroll
  127. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  128. const int i = i0 + threadIdx.x;
  129. const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
  130. int u[2*VDR_Q4_0_Q8_1_MMQ];
  131. #pragma unroll
  132. for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
  133. u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
  134. u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_0) % WARP_SIZE];
  135. }
  136. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
  137. (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
  138. y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
  139. }
  140. }
  141. }
  142. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
  143. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  144. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  145. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  146. const int kbx = threadIdx.x / QI4_1;
  147. const int kqsx = threadIdx.x % QI4_1;
  148. #pragma unroll
  149. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  150. int i = i0 + threadIdx.y;
  151. if (need_check) {
  152. i = min(i, i_max);
  153. }
  154. const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx;
  155. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  156. }
  157. const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
  158. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  159. #pragma unroll
  160. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
  161. int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row;
  162. if (need_check) {
  163. i = min(i, i_max);
  164. }
  165. const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd;
  166. x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
  167. }
  168. }
  169. template <int mmq_x, int mmq_y, int nwarps>
  170. static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
  171. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  172. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  173. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  174. const int * y_qs = (const int *) y + 4;
  175. const half2 * y_ds = (const half2 *) y;
  176. #pragma unroll
  177. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  178. const int j = j0 + threadIdx.y;
  179. #pragma unroll
  180. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  181. const int i = i0 + threadIdx.x;
  182. const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
  183. int u[2*VDR_Q4_1_Q8_1_MMQ];
  184. #pragma unroll
  185. for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
  186. u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
  187. u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI4_1) % WARP_SIZE];
  188. }
  189. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
  190. (&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/QI4_1],
  191. y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
  192. }
  193. }
  194. }
  195. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
  196. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  197. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  198. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  199. const int kbx = threadIdx.x / QI5_0;
  200. const int kqsx = threadIdx.x % QI5_0;
  201. #pragma unroll
  202. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  203. int i = i0 + threadIdx.y;
  204. if (need_check) {
  205. i = min(i, i_max);
  206. }
  207. const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx;
  208. const int ql = get_int_from_uint8(bxi->qs, kqsx);
  209. const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0));
  210. int qs0 = (ql >> 0) & 0x0F0F0F0F;
  211. qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
  212. qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
  213. qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
  214. qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
  215. qs0 = __vsubss4(qs0, 0x10101010); // subtract 16
  216. x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
  217. int qs1 = (ql >> 4) & 0x0F0F0F0F;
  218. qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
  219. qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
  220. qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
  221. qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
  222. qs1 = __vsubss4(qs1, 0x10101010); // subtract 16
  223. x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
  224. }
  225. const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
  226. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  227. float * x_dmf = (float *) x_dm;
  228. #pragma unroll
  229. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
  230. int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row;
  231. if (need_check) {
  232. i = min(i, i_max);
  233. }
  234. const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd;
  235. x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
  236. }
  237. }
  238. template <int mmq_x, int mmq_y, int nwarps>
  239. static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
  240. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  241. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  242. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  243. const float * x_dmf = (const float *) x_dm;
  244. const int * y_qs = (const int *) y + 4;
  245. const float * y_df = (const float *) y;
  246. #pragma unroll
  247. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  248. const int j = j0 + threadIdx.y;
  249. #pragma unroll
  250. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  251. const int i = i0 + threadIdx.x;
  252. const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
  253. const int index_bx = i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0;
  254. int u[2*VDR_Q5_0_Q8_1_MMQ];
  255. #pragma unroll
  256. for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
  257. u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
  258. u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_0) % WARP_SIZE];
  259. }
  260. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, QR5_0*VDR_Q5_0_Q8_1_MMQ>
  261. (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dmf[index_bx], y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
  262. }
  263. }
  264. }
  265. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
  266. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  267. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  268. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  269. const int kbx = threadIdx.x / QI5_1;
  270. const int kqsx = threadIdx.x % QI5_1;
  271. #pragma unroll
  272. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  273. int i = i0 + threadIdx.y;
  274. if (need_check) {
  275. i = min(i, i_max);
  276. }
  277. const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx;
  278. const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
  279. const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1));
  280. int qs0 = (ql >> 0) & 0x0F0F0F0F;
  281. qs0 |= (qh << 4) & 0x00000010; // 0 -> 4
  282. qs0 |= (qh << 11) & 0x00001000; // 1 -> 12
  283. qs0 |= (qh << 18) & 0x00100000; // 2 -> 20
  284. qs0 |= (qh << 25) & 0x10000000; // 3 -> 28
  285. x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+0] = qs0;
  286. int qs1 = (ql >> 4) & 0x0F0F0F0F;
  287. qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4
  288. qs1 |= (qh >> 5) & 0x00001000; // 17 -> 12
  289. qs1 |= (qh << 2) & 0x00100000; // 18 -> 20
  290. qs1 |= (qh << 9) & 0x10000000; // 19 -> 28
  291. x_ql[i * (2*WARP_SIZE + 1) + 2*threadIdx.x+1] = qs1;
  292. }
  293. const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
  294. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  295. #pragma unroll
  296. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
  297. int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row;
  298. if (need_check) {
  299. i = min(i, i_max);
  300. }
  301. const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd;
  302. x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
  303. }
  304. }
  305. template <int mmq_x, int mmq_y, int nwarps>
  306. static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
  307. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  308. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  309. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  310. const int * y_qs = (const int *) y + 4;
  311. const half2 * y_ds = (const half2 *) y;
  312. #pragma unroll
  313. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  314. const int j = j0 + threadIdx.y;
  315. #pragma unroll
  316. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  317. const int i = i0 + threadIdx.x;
  318. const int kyqs = k0 % (QI8_1/2) + QI8_1 * (k0 / (QI8_1/2));
  319. const int index_bx = i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1;
  320. int u[2*VDR_Q5_1_Q8_1_MMQ];
  321. #pragma unroll
  322. for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
  323. u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l) % WARP_SIZE];
  324. u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + (kyqs + l + QI5_1) % WARP_SIZE];
  325. }
  326. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
  327. (&x_ql[i*(2*WARP_SIZE + 1) + 2*k0], u, x_dm[index_bx], y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
  328. }
  329. }
  330. }
  331. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
  332. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  333. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  334. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  335. const int kbx = threadIdx.x / QI8_0;
  336. const int kqsx = threadIdx.x % QI8_0;
  337. float * x_dmf = (float *) x_dm;
  338. #pragma unroll
  339. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  340. int i = i0 + threadIdx.y;
  341. if (need_check) {
  342. i = min(i, i_max);
  343. }
  344. const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx;
  345. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_int8(bxi->qs, kqsx);
  346. }
  347. const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
  348. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  349. #pragma unroll
  350. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
  351. int i = i0 + threadIdx.y * QI8_0 + threadIdx.x / blocks_per_tile_x_row;
  352. if (need_check) {
  353. i = min(i, i_max);
  354. }
  355. const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd;
  356. x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
  357. }
  358. }
  359. template <int mmq_x, int mmq_y, int nwarps>
  360. static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
  361. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  362. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  363. GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
  364. const float * x_dmf = (const float *) x_dm;
  365. const int * y_qs = (const int *) y + 4;
  366. const float * y_df = (const float *) y;
  367. #pragma unroll
  368. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  369. const int j = j0 + threadIdx.y;
  370. #pragma unroll
  371. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  372. const int i = i0 + threadIdx.x;
  373. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl<float, VDR_Q8_0_Q8_1_MMQ>
  374. (&x_ql[i*(WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0], x_dmf[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0],
  375. y_df[j*MMQ_TILE_Y_K + k0/QI8_1]);
  376. }
  377. }
  378. }
  379. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
  380. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  381. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  382. GGML_UNUSED(x_qh);
  383. const int kbx = threadIdx.x / QI2_K;
  384. const int kqsx = threadIdx.x % QI2_K;
  385. #pragma unroll
  386. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  387. int i = i0 + threadIdx.y;
  388. if (need_check) {
  389. i = min(i, i_max);
  390. }
  391. const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbx;
  392. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8_aligned(bxi->qs, kqsx);
  393. }
  394. const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
  395. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  396. #pragma unroll
  397. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
  398. int i = (i0 + threadIdx.y * QI2_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
  399. if (need_check) {
  400. i = min(i, i_max);
  401. }
  402. const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + kbxd;
  403. x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
  404. }
  405. #pragma unroll
  406. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
  407. int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
  408. if (need_check) {
  409. i = min(i, i_max);
  410. }
  411. const block_q2_K * bxi = (const block_q2_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI2_K/4);
  412. x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, threadIdx.x % (QI2_K/4));
  413. }
  414. }
  415. template <int mmq_x, int mmq_y, int nwarps>
  416. static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mul_mat(
  417. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  418. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  419. GGML_UNUSED(x_qh);
  420. const int * y_qs = (const int *) y + 4;
  421. const float * y_df = (const float *) y;
  422. #pragma unroll
  423. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  424. const int j = j0 + threadIdx.y;
  425. #pragma unroll
  426. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  427. const int i = i0 + threadIdx.x;
  428. const int kbx = k0 / QI2_K;
  429. const int ky = (k0 % QI2_K) * QR2_K;
  430. int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
  431. const int kqsx = i*(WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
  432. const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
  433. #pragma unroll
  434. for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
  435. v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
  436. }
  437. const uint8_t * scales = ((const uint8_t *) &x_sc[i*(WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
  438. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq(
  439. v, &y_qs[j*MMQ_TILE_Y_K + (QR2_K*k0) % WARP_SIZE], scales,
  440. x_dm[i*(WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[j*MMQ_TILE_Y_K + ((QR2_K*k0) % WARP_SIZE)/QI8_1]);
  441. }
  442. }
  443. }
  444. template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
  445. const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
  446. int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
  447. const int kbx = threadIdx.x / QI3_K;
  448. const int kqsx = threadIdx.x % QI3_K;
  449. #pragma unroll
  450. for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
  451. int i = i0 + threadIdx.y;
  452. if (need_check) {
  453. i = min(i, i_max);
  454. }
  455. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbx;
  456. x_ql[i * (WARP_SIZE + 1) + threadIdx.x] = get_int_from_uint8(bxi->qs, kqsx);
  457. }
  458. const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
  459. const int kbxd = threadIdx.x % blocks_per_tile_x_row;
  460. float * x_dmf = (float *) x_dm;
  461. #pragma unroll
  462. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
  463. int i = (i0 + threadIdx.y * QI3_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y;
  464. if (need_check) {
  465. i = min(i, i_max);
  466. }
  467. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + kbxd;
  468. x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
  469. }
  470. #pragma unroll
  471. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
  472. int i = i0 + threadIdx.y * 2 + threadIdx.x / (WARP_SIZE/2);
  473. if (need_check) {
  474. i = min(i, i_max);
  475. }
  476. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/2)) / (QI3_K/2);
  477. // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
  478. x_qh[i * (WARP_SIZE/2) + i / 2 + threadIdx.x % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, threadIdx.x % (QI3_K/2));
  479. }
  480. #pragma unroll
  481. for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
  482. int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
  483. if (need_check) {
  484. i = min(i, i_max);
  485. }
  486. const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/4)) / (QI3_K/4);
  487. const int ksc = threadIdx.x % (QI3_K/4);
  488. const int ksc_low = ksc % (QI3_K/8);
  489. const int shift_low = 4 * (ksc / (QI3_K/8));
  490. const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
  491. const int ksc_high = QI3_K/8;
  492. const int shift_high = 2 * ksc;
  493. const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
  494. const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
  495. x_sc[i * (WARP_SIZE/4) + i / 4 + threadIdx.x % (WARP_SIZE/4)] = sc;
  496. }
  497. }
  498. template <int mmq_x, int mmq_y, int nwarps>
  499. static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mul_mat(
  500. const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
  501. const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
  502. const float * x_dmf = (const float *) x_dm;
  503. const int * y_qs = (const int *) y + 4;
  504. const float * y_df = (const float *) y;
  505. #pragma unroll
  506. for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
  507. const int j = j0 + threadIdx.y;
  508. #pragma unroll
  509. for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
  510. const int i = i0 + threadIdx.x;
  511. const int kbx = k0 / QI3_K;
  512. const int ky = (k0 % QI3_K) * QR3_K;
  513. const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
  514. int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
  515. #pragma unroll
  516. for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
  517. const int kqsx = i*(WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
  518. const int shift = 2 * ((ky % 32) / 8);
  519. const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
  520. const int vh = x_qh[i*(WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
  521. const int vlh = (vh << 2) & 0x04040404;
  522. v[l] = __vsubss4(vll, vlh);
  523. }
  524. sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq(
  525. v, &y_qs[j*MMQ_TILE_Y_K + (k0*QR3_K) % WARP_SIZE], scales,
  526. x_dmf[i*(WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[j*MMQ_TILE_Y_K + ((k0*QR3_K) % WARP_SIZE)/QI8_1]);
  527. }
  528. }
  529. }
  // Fills the x tile buffers for q4_K data:
  //   x_ql - the raw 4-bit quants, one int per (row, threadIdx.x), rows padded to WARP_SIZE + 1 ints
  //   x_dm - the per-block dm value
  //   x_sc - the scales/mins, repacked from 6 bits to one byte each
  // x_qh is not used for this type.
  //
  // kbx0   - offset (in blocks) of the first block to read, relative to x
  // i_max  - upper clamp applied to the tile row index when need_check is set
  // stride - distance (in blocks) between consecutive rows of x
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
      const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
      GGML_UNUSED(x_qh);

      const block_q4_K * const blocks = (const block_q4_K *) x;

      const int tx = threadIdx.x;
      const int ty = threadIdx.y;

      // --- quants: thread (tx, ty) copies int number tx of row i0 + ty ---
      const int block_in_row = 0; // tx / QI4_K

  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
          const int row_raw = i0 + ty;
          const int row     = need_check ? min(row_raw, i_max) : row_raw;

          const block_q4_K * src = blocks + kbx0 + row*stride + block_in_row;

          x_ql[row * (WARP_SIZE + 1) + tx] = get_int_from_uint8_aligned(src->qs, tx); // int index == tx % QI4_K
      }

      // --- dm values ---
      const int blocks_per_row = WARP_SIZE / QI4_K;   // == 1 if QK_K == 256
      const int dm_block       = tx % blocks_per_row; // == 0 if QK_K == 256

  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
          const int row_raw = (i0 + ty * QI4_K + tx / blocks_per_row) % mmq_y;
          const int row     = need_check ? min(row_raw, i_max) : row_raw;

          const block_q4_K * src = blocks + kbx0 + row*stride + dm_block;

          x_dm[row * (WARP_SIZE/QI4_K) + row / QI4_K + dm_block] = src->dm;
      }

      // --- scales and mins: WARP_SIZE/8 ints per row ---
      const int ksc = tx % (WARP_SIZE/8);

  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
          const int row_raw = (i0 + ty * 8 + tx / (WARP_SIZE/8)) % mmq_y;
          const int row     = need_check ? min(row_raw, i_max) : row_raw;

          const block_q4_K * src = blocks + kbx0 + row*stride + ksc / (QI4_K/8);
          const int        * raw = (const int *) src->scales;

          // scale arrangement once both parts are combined: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
          const int low4  = (raw[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
          const int high2 = (raw[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits

          x_sc[row * (WARP_SIZE/8) + row / 8 + ksc] = low4 | high2;
      }
  }
  // For every (i, j) pair owned by this thread, adds the result of vec_dot_q4_K_q8_1_impl_mmq
  // for the k-slice starting at k0 to the matching entry of sum.
  //
  // x_ql / x_dm / x_sc - x tile buffers as laid out by load_tiles_q4_K (x_qh is unused)
  // y                  - y tile; read as half2 ds values from its start and as int quants from int offset 4
  // sum                - per-thread accumulators, indexed by (j0/nwarps, i0/WARP_SIZE)
  template <int mmq_x, int mmq_y, int nwarps>
  static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
      GGML_UNUSED(x_qh);

      const half2 * y_ds = (const half2 *) y;
      const int   * y_qs = (const int   *) y + 4;

      // Offsets that depend on k0 only.
      const int ky      = (QR4_K*k0) % WARP_SIZE; // int offset into a y tile row
      const int sc_word = k0/16;                  // which int of the row's scales
      const int sc_byte = 2*((k0 % 16) / 8);      // byte offset inside that int

  #pragma unroll
      for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
          const int col = j0 + threadIdx.y;

          const int   * u  = &y_qs[col*MMQ_TILE_Y_K + ky];
          const half2 * ds = &y_ds[col*MMQ_TILE_Y_K + ky/QI8_1];

  #pragma unroll
          for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
              const int row = i0 + threadIdx.x;

              const uint8_t * sc = ((const uint8_t *) &x_sc[row * (WARP_SIZE/8) + row/8 + sc_word]) + sc_byte;
              const uint8_t * m  = sc + 8; // mins follow the scales by 8 bytes

              sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq(
                  &x_ql[row*(WARP_SIZE + 1) + k0], u, sc, m,
                  x_dm[row*(WARP_SIZE/QI4_K) + row/QI4_K], ds);
          }
      }
  }
  // Fills the x tile buffers for q5_K data:
  //   x_ql - quants with the 5th bit already merged in (two ints written per int of qs read),
  //          rows padded to 2*WARP_SIZE + 1 ints
  //   x_dm - the per-block dm value
  //   x_sc - the scales/mins, repacked from 6 bits to one byte each
  // x_qh is not used for this type.
  //
  // kbx0   - offset (in blocks) of the first block to read, relative to x
  // i_max  - upper clamp applied to the tile row index when need_check is set
  // stride - distance (in blocks) between consecutive rows of x
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
      const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
      GGML_UNUSED(x_qh);

      const block_q5_K * const blocks = (const block_q5_K *) x;

      const int tx = threadIdx.x;
      const int ty = threadIdx.y;

      // --- quants ---
      const int block_in_row = 0; // tx / QI5_K

      // Values that depend on the thread only (tx plays the role of tx % QI5_K).
      const int ky       = QR5_K*tx;
      const int qh_idx   = tx % (QI5_K/4);       // which int of qh holds this thread's high bits
      const int qh_shift = 2 * (tx / (QI5_K/4)); // bit position of the first high bit inside each byte
      const int kq0      = ky - ky % (QI5_K/2) + tx % (QI5_K/4) + 0;
      const int kq1      = ky - ky % (QI5_K/2) + tx % (QI5_K/4) + (QI5_K/4);

  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
          const int row_raw = i0 + ty;
          const int row     = need_check ? min(row_raw, i_max) : row_raw;

          const block_q5_K * src = blocks + kbx0 + row*stride + block_in_row;

          const int ql = get_int_from_uint8_aligned(src->qs, tx);
          const int qh = get_int_from_uint8_aligned(src->qh, qh_idx);

          const int low_nibbles  = (ql >> 0) & 0x0F0F0F0F;
          const int high_nibbles = (ql >> 4) & 0x0F0F0F0F;

          // Move the selected high bit of every byte to bit position 4.
          const int bit4_low  = ((qh >> (qh_shift + 0)) << 4) & 0x10101010;
          const int bit4_high = ((qh >> (qh_shift + 1)) << 4) & 0x10101010;

          int * dst_row = x_ql + row * (2*WARP_SIZE + 1);
          dst_row[kq0] = low_nibbles  | bit4_low;
          dst_row[kq1] = high_nibbles | bit4_high;
      }

      // --- dm values ---
      const int blocks_per_row = WARP_SIZE / QI5_K;   // == 1 if QK_K == 256
      const int dm_block       = tx % blocks_per_row; // == 0 if QK_K == 256

  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
          const int row_raw = (i0 + ty * QI5_K + tx / blocks_per_row) % mmq_y;
          const int row     = need_check ? min(row_raw, i_max) : row_raw;

          const block_q5_K * src = blocks + kbx0 + row*stride + dm_block;

          x_dm[row * (WARP_SIZE/QI5_K) + row / QI5_K + dm_block] = src->dm;
      }

      // --- scales and mins: WARP_SIZE/8 ints per row ---
      const int ksc = tx % (WARP_SIZE/8);

  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
          const int row_raw = (i0 + ty * 8 + tx / (WARP_SIZE/8)) % mmq_y;
          const int row     = need_check ? min(row_raw, i_max) : row_raw;

          const block_q5_K * src = blocks + kbx0 + row*stride + ksc / (QI5_K/8);
          const int        * raw = (const int *) src->scales;

          // scale arrangement once both parts are combined: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
          const int low4  = (raw[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
          const int high2 = (raw[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits

          x_sc[row * (WARP_SIZE/8) + row / 8 + ksc] = low4 | high2;
      }
  }
  // For every (i, j) pair owned by this thread, adds the result of vec_dot_q5_K_q8_1_impl_mmq
  // for the k-slice starting at k0 to the matching entry of sum.
  //
  // x_ql / x_dm / x_sc - x tile buffers as laid out by load_tiles_q5_K (x_qh is unused)
  // y                  - y tile; read as half2 ds values from its start and as int quants from int offset 4
  // sum                - per-thread accumulators, indexed by (j0/nwarps, i0/WARP_SIZE)
  template <int mmq_x, int mmq_y, int nwarps>
  static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
      GGML_UNUSED(x_qh);

      const half2 * y_ds = (const half2 *) y;
      const int   * y_qs = (const int   *) y + 4;

      // Offsets that depend on k0 only.
      const int ky      = (QR5_K*k0) % WARP_SIZE; // int offset into a y tile row
      const int sc_word = k0/16;                  // which int of the row's scales
      const int sc_byte = 2 * ((k0 % 16) / 8);    // byte offset inside that int

  #pragma unroll
      for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
          const int col = j0 + threadIdx.y;

          const int   * u  = &y_qs[col*MMQ_TILE_Y_K + ky];
          const half2 * ds = &y_ds[col*MMQ_TILE_Y_K + ky/QI8_1];

  #pragma unroll
          for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
              const int row = i0 + threadIdx.x;

              const uint8_t * sc = ((const uint8_t *) &x_sc[row * (WARP_SIZE/8) + row/8 + sc_word]) + sc_byte;
              const uint8_t * m  = sc + 8; // mins follow the scales by 8 bytes

              sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq(
                  &x_ql[row*(QR5_K*WARP_SIZE + 1) + QR5_K*k0], u, sc, m,
                  x_dm[row*(WARP_SIZE/QI5_K) + row/QI5_K], ds);
          }
      }
  }
  // Fills the x tile buffers for q6_K data:
  //   x_ql - quants with the upper 2 bits merged in and 0x20 subtracted from every byte
  //          (two ints written per int of ql read), rows padded to 2*WARP_SIZE + 1 ints
  //   x_dm - reinterpreted as float storage; receives the per-block d value
  //   x_sc - the int8 scales, copied four at a time
  // x_qh is not used for this type.
  //
  // kbx0   - offset (in blocks) of the first block to read, relative to x
  // i_max  - upper clamp applied to the tile row index when need_check is set
  // stride - distance (in blocks) between consecutive rows of x
  template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
      const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
      int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
      GGML_UNUSED(x_qh);

      const block_q6_K * const blocks = (const block_q6_K *) x;

      const int tx = threadIdx.x;
      const int ty = threadIdx.y;

      // --- quants ---
      const int block_in_row = 0; // tx / QI6_K

      // Values that depend on the thread only (tx plays the role of tx % QI6_K).
      const int ky       = QR6_K*tx;
      const int qh_idx   = (QI6_K/4) * (tx / (QI6_K/2)) + tx % (QI6_K/4); // which int of qh to read
      const int qh_shift = 2 * ((tx % (QI6_K/2)) / (QI6_K/4));            // bit offset of the 2-bit pairs used
      const int kq0      = ky - ky % QI6_K + tx % (QI6_K/2) + 0;
      const int kq1      = ky - ky % QI6_K + tx % (QI6_K/2) + (QI6_K/2);

  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
          const int row_raw = i0 + ty;
          const int row     = need_check ? min(row_raw, i_max) : row_raw;

          const block_q6_K * src = blocks + kbx0 + row*stride + block_in_row;

          const int ql = get_int_from_uint8(src->ql, tx);
          const int qh = get_int_from_uint8(src->qh, qh_idx);

          const int low_nibbles  = (ql >> 0) & 0x0F0F0F0F;
          const int high_nibbles = (ql >> 4) & 0x0F0F0F0F;

          // Place two high bits per byte at bit positions 4-5.
          const int high_bits_0 = ((qh >> qh_shift) << 4) & 0x30303030;
          const int high_bits_1 =  (qh >> qh_shift)       & 0x30303030;

          int * dst_row = x_ql + row * (2*WARP_SIZE + 1);
          dst_row[kq0] = __vsubss4(low_nibbles  | high_bits_0, 0x20202020);
          dst_row[kq1] = __vsubss4(high_nibbles | high_bits_1, 0x20202020);
      }

      // --- d values, stored as float ---
      const int blocks_per_row = WARP_SIZE / QI6_K;   // == 1 if QK_K == 256
      const int d_block        = tx % blocks_per_row; // == 0 if QK_K == 256
      float * x_dmf = (float *) x_dm;

  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
          const int row_raw = (i0 + ty * QI6_K + tx / blocks_per_row) % mmq_y;
          const int row     = need_check ? min(row_raw, i_max) : row_raw;

          const block_q6_K * src = blocks + kbx0 + row*stride + d_block;

          x_dmf[row * (WARP_SIZE/QI6_K) + row / QI6_K + d_block] = src->d;
      }

      // --- scales: WARP_SIZE/8 ints per row ---
      const int sc_slot = tx % (WARP_SIZE/8);

  #pragma unroll
      for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
          const int row_raw = (i0 + ty * 8 + tx / (WARP_SIZE/8)) % mmq_y;
          const int row     = need_check ? min(row_raw, i_max) : row_raw;

          const block_q6_K * src = blocks + kbx0 + row*stride + sc_slot / 4;

          x_sc[row * (WARP_SIZE/8) + row / 8 + sc_slot] = get_int_from_int8(src->scales, tx % (QI6_K/8));
      }
  }
  // For every (i, j) pair owned by this thread, adds the result of vec_dot_q6_K_q8_1_impl_mmq
  // for the k-slice starting at k0 to the matching entry of sum.
  //
  // x_ql / x_dm / x_sc - x tile buffers as laid out by load_tiles_q6_K; x_dm is read as float (x_qh is unused)
  // y                  - y tile; read as float scales from its start and as int quants from int offset 4
  // sum                - per-thread accumulators, indexed by (j0/nwarps, i0/WARP_SIZE)
  template <int mmq_x, int mmq_y, int nwarps>
  static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
      const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
      const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
      GGML_UNUSED(x_qh);

      const float * x_dmf = (const float *) x_dm;
      const float * y_df  = (const float *) y;
      const int   * y_qs  = (const int   *) y + 4;

      // Offsets that depend on k0 only.
      const int ky      = (QR6_K*k0) % WARP_SIZE; // int offset into a y tile row
      const int sc_word = k0/8;                   // which int of the row's scales

  #pragma unroll
      for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
          const int col = j0 + threadIdx.y;

          const int   * u  = &y_qs[col*MMQ_TILE_Y_K + ky];
          const float * d8 = &y_df[col*MMQ_TILE_Y_K + ky/QI8_1];

  #pragma unroll
          for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
              const int row = i0 + threadIdx.x;

              const int8_t * sc = ((const int8_t *) &x_sc[row * (WARP_SIZE/8) + row/8 + sc_word]);

              sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq(
                  &x_ql[row*(QR6_K*WARP_SIZE + 1) + QR6_K*k0], u, sc,
                  x_dmf[row*(WARP_SIZE/QI6_K) + row/QI6_K], d8);
          }
      }
  }
  730. // -------------------------------------------------------------------------------------------------------------------------------------
  731. template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
  732. struct mmq_type_traits;
  // mul_mat_q hooks for GGML_TYPE_Q4_0: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q4_0> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_0_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q4_0_Q8_1_MMQ;
  };
  // mul_mat_q hooks for GGML_TYPE_Q4_1: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q4_1> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_1_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q4_1_Q8_1_MMQ;
  };
  // mul_mat_q hooks for GGML_TYPE_Q5_0: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q5_0> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_0_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q5_0_Q8_1_MMQ;
  };
  // mul_mat_q hooks for GGML_TYPE_Q5_1: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q5_1> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_1_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q5_1_Q8_1_MMQ;
  };
  // mul_mat_q hooks for GGML_TYPE_Q8_0: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q8_0> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q8_0_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q8_0_Q8_1_MMQ;
  };
  // mul_mat_q hooks for GGML_TYPE_Q2_K: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q2_K> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q2_K_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q2_K_Q8_1_MMQ;
  };
  // mul_mat_q hooks for GGML_TYPE_Q3_K: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q3_K> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q3_K_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q3_K_Q8_1_MMQ;
  };
  // mul_mat_q hooks for GGML_TYPE_Q4_K: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q4_K> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q4_K_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q4_K_Q8_1_MMQ;
  };
  // mul_mat_q hooks for GGML_TYPE_Q5_K: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q5_K> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q5_K_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q5_K_Q8_1_MMQ;
  };
  // mul_mat_q hooks for GGML_TYPE_Q6_K: the tile loader, the tile dot product and
  // the k0 step (vdr) used between successive vec_dot calls.
  template <int tile_x, int tile_y, int n_warps, bool check_bounds>
  struct mmq_type_traits<tile_x, tile_y, n_warps, check_bounds, GGML_TYPE_Q6_K> {
      static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<tile_y, n_warps, check_bounds>;
      static constexpr vec_dot_mmq_t    vec_dot    = vec_dot_q6_K_q8_1_mul_mat<tile_x, tile_y, n_warps>;
      static constexpr int              vdr        = VDR_Q6_K_Q8_1_MMQ;
  };
  // Per-type flag lookup: returns true for Q4_0, Q4_1, Q5_1, Q4_K and Q5_K, false for
  // Q5_0, Q8_0, Q2_K, Q3_K and Q6_K. Any other type trips GGML_ASSERT.
  // NOTE(review): presumably tells the caller whether the q8_1 data for y must carry
  // block sums for this x type - confirm against the call site.
  static int mmq_need_sum(const ggml_type type_x) {
      switch (type_x) {
          // types reported as needing the sum
          case GGML_TYPE_Q4_0:
          case GGML_TYPE_Q4_1:
          case GGML_TYPE_Q5_1:
          case GGML_TYPE_Q4_K:
          case GGML_TYPE_Q5_K:
              return true;

          // types reported as not needing it
          case GGML_TYPE_Q5_0:
          case GGML_TYPE_Q8_0:
          case GGML_TYPE_Q2_K:
          case GGML_TYPE_Q3_K:
          case GGML_TYPE_Q6_K:
              return false;

          default:
              GGML_ASSERT(false);
              break;
      }
      return false;
  }
  // Tiled quantized matrix multiplication kernel.
  //
  // Launch layout (as set up by launch_mul_mat_q):
  //   block = (WARP_SIZE, nwarps, 1)
  //   grid  = (tiles of mmq_y rows over ne01, tiles of mmq_x columns over ne11, 1)
  // Dynamic shared memory holds, in order, the x tile buffers (ql, dm, qh, sc; sizes from
  // get_tile_x_sizes_device) followed by the y tile.
  //
  // x        - quantized x data, read through the type's load_tiles function
  // yc       - y data, indexed in units of block_q8_1_mmq
  // dst      - output; element (i, j) is written to dst[j*ne0 + i]
  // ne00     - row length of x in values (divided by qk to get blocks per row)
  // ne01     - number of rows of x; used to clamp tile rows
  // stride01 - distance between rows of x, in blocks
  // ne10     - not used by the kernel
  // ne11     - number of y columns; threads whose column is >= ne11 do not write
  // stride11 - multiplier applied to the y offset of each k step
  // ne0      - row count / leading dimension of dst
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  #if defined(RDNA3) || defined(RDNA2)
      __launch_bounds__(WARP_SIZE*nwarps, 2)
  #endif // defined(RDNA3) || defined(RDNA2)
  #else
  #if __CUDA_ARCH__ >= CC_VOLTA
      __launch_bounds__(WARP_SIZE*nwarps, 1)
  #else
      __launch_bounds__(WARP_SIZE*nwarps, type == GGML_TYPE_Q2_K ? 1 : 2)
  #endif // __CUDA_ARCH__ >= CC_VOLTA
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  static __global__ void mul_mat_q(
      const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
      const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {

      // Skip unused template specializations for faster compilation:
      if (mmq_x > get_mmq_x_max_device()) {
          NO_DEVICE_CODE;
          return;
      }

      constexpr int mmq_y = get_mmq_y_device(mmq_x);
      using traits = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>;

      constexpr int qk = ggml_cuda_type_traits<type>::qk;
      constexpr int qr = ggml_cuda_type_traits<type>::qr;
      constexpr int qi = ggml_cuda_type_traits<type>::qi;

      constexpr int              vdr        = traits::vdr;
      constexpr load_tiles_mmq_t load_tiles = traits::load_tiles;
      constexpr vec_dot_mmq_t    vec_dot    = traits::vec_dot;

      constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);

      // Carve the dynamic shared memory into the individual tile buffers.
      extern __shared__ char data_mul_mat_q[];
      int   * smem_x_ql = (int   *)  data_mul_mat_q;
      half2 * smem_x_dm = (half2 *) (smem_x_ql + txs.ql);
      int   * smem_x_qh = (int   *) (smem_x_dm + txs.dm);
      int   * smem_x_sc = (int   *) (smem_x_qh + txs.qh);
      int   * smem_y    = (int   *) (smem_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]

      const int blocks_per_row_x = ne00 / qk;
      const int blocks_per_warp  = WARP_SIZE / qi;

      // Highest valid row index inside this block's x tile.
      const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;

      const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));

      // Linear thread index inside the block.
      const int tid = threadIdx.y*WARP_SIZE + threadIdx.x;

      float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};

      for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {

          load_tiles(x, smem_x_ql, smem_x_dm, smem_x_qh, smem_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);

  #pragma unroll
          for (int kr = 0; kr < qr; ++kr) {
              const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + kr*sizeof(block_q8_1_mmq)/sizeof(int));

              // Every thread copies one int per pass; the loop bound is not rounded to the block size,
              // so the last pass can touch indices >= mmq_x*MMQ_TILE_Y_K.
  #pragma unroll
              for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
                  const int l = l0 + tid;

                  smem_y[l] = by0[l];
              }

              // Tiles must be fully written before any thread reads them.
              __syncthreads();

  // #pragma unroll // unrolling this loop causes too much register pressure
              for (int k0 = kr*WARP_SIZE/qr; k0 < (kr+1)*WARP_SIZE/qr; k0 += vdr) {
                  vec_dot(smem_x_ql, smem_x_dm, smem_x_qh, smem_x_sc, smem_y, sum, k0);
              }

              // All reads must finish before the tiles are overwritten in the next pass.
              __syncthreads();
          }
      }

      // Write the accumulators back to dst.
  #pragma unroll
      for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
          const int col = blockIdx.y*mmq_x + j0 + threadIdx.y;

          if (col >= ne11) {
              return;
          }

  #pragma unroll
          for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
              const int row = blockIdx.x*mmq_y + i0 + threadIdx.x;

              if (!need_check || row < ne0) {
                  dst[col*ne0 + row] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
              }
          }
      }
  }
  // Argument bundle handed to launch_mul_mat_q / mul_mat_q_case and forwarded, member by
  // member, to the mul_mat_q kernel. Member order matters for aggregate initialization.
  struct mmq_args {
      const char * x;        // quantized x data
      const char * y;        // quantized y data
      float      * dst;      // output buffer

      int64_t ne00;          // forwarded as the kernel's ne00
      int64_t ne01;          // forwarded as ne01; also sizes the grid in x
      int64_t stride01;      // forwarded as stride01

      int64_t ne10;          // forwarded as ne10
      int64_t ne11;          // forwarded as ne11; also sizes the grid in y
      int64_t stride11;      // forwarded as stride11

      int64_t ne0;           // forwarded as ne0
  };
  // Launches mul_mat_q<type, mmq_x, nwarps, need_check> on `stream` for the current device.
  //
  // Grid:  ceil(ne01 / mmq_y) x ceil(ne11 / mmq_x) blocks, with mmq_y from get_mmq_y_host(cc, mmq_x).
  // Block: WARP_SIZE x nwarps threads.
  // Dynamic shared memory: the x tile buffers plus the y tile, the latter padded to a multiple of
  // the block size in ints because the kernel's y-tile copy runs in whole-block passes.
  //
  // The bounds-checking kernel variant (need_check == true) is used only when ne01 is not a
  // multiple of mmq_y. A failed launch (e.g. an invalid configuration) is reported through
  // CUDA_CHECK right after the launch instead of surfacing at some later, unrelated call.
  template <ggml_type type, int mmq_x, int nwarps>
  static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
      const int id = ggml_cuda_get_device();
      const int cc = ggml_cuda_info().devices[id].cc;
      const int mmq_y = get_mmq_y_host(cc, mmq_x);

      const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
      const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
      const dim3 block_nums(block_num_x, block_num_y, 1);
      const dim3 block_dims(WARP_SIZE, nwarps, 1);

      const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
      const int shmem_x = txs.ql*sizeof(int) + txs.dm*sizeof(half2) + txs.qh*sizeof(int) + txs.sc*sizeof(int);
      const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
      const int shmem = shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));

  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
      // Raise the dynamic shared memory limit once per device and per template instantiation.
      static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
      if (!shmem_limit_raised[id]) {
          CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
          CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
          shmem_limit_raised[id] = true;
      }
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

      if (args.ne01 % mmq_y == 0) {
          const bool need_check = false;
          mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
              (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
      } else {
          const bool need_check = true;
          mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
              (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
      }

      // Kernel launches do not return a status; fetch launch-time errors explicitly.
      CUDA_CHECK(cudaGetLastError());
  }
  // Picks the y tile width mmq_x (a multiple of 8 up to get_mmq_x_max_host(cc)) that gives the
  // fewest waves of thread blocks over the device's nsm multiprocessors, then launches the matching
  // launch_mul_mat_q instantiation. Ties keep the smaller mmq_x; the search stops early once a
  // single wave suffices. Aborts via GGML_ASSERT if no supported width was selected.
  template <ggml_type type>
  void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
      const int device = ggml_cuda_get_device();
      const int nsm    = ggml_cuda_info().devices[device].nsm;
      const int cc     = ggml_cuda_info().devices[device].cc;

      const int mmq_x_max = get_mmq_x_max_host(cc);

      // NOTE(review): the row-tile count below uses mmq_y for mmq_x_max, while launch_mul_mat_q
      // derives mmq_y from the mmq_x actually chosen - confirm the estimate is meant to ignore that.
      const int mmq_y       = get_mmq_y_host(cc, mmq_x_max);
      const int ntiles_rows = (args.ne01 + mmq_y - 1) / mmq_y;

      int best_mmq_x  = 0;
      int best_nwaves = INT_MAX;

      for (int candidate = 8; candidate <= mmq_x_max && best_nwaves > 1; candidate += 8) {
          const int ntiles_cols = (args.ne11 + candidate - 1) / candidate;
          const int nwaves      = (ntiles_cols*ntiles_rows + nsm - 1) / nsm;

          if (nwaves >= best_nwaves) {
              continue; // not strictly better than what we already have
          }

          best_mmq_x  = candidate;
          best_nwaves = nwaves;
      }

      // The tile width is a template parameter, so every supported value needs its own case.
      switch (best_mmq_x) {
          case   8: launch_mul_mat_q<type,   8, 4>(args, stream); break;
          case  16: launch_mul_mat_q<type,  16, 8>(args, stream); break;
          case  24: launch_mul_mat_q<type,  24, 8>(args, stream); break;
          case  32: launch_mul_mat_q<type,  32, 8>(args, stream); break;
          case  40: launch_mul_mat_q<type,  40, 8>(args, stream); break;
          case  48: launch_mul_mat_q<type,  48, 8>(args, stream); break;
          case  56: launch_mul_mat_q<type,  56, 8>(args, stream); break;
          case  64: launch_mul_mat_q<type,  64, 8>(args, stream); break;
          case  72: launch_mul_mat_q<type,  72, 8>(args, stream); break;
          case  80: launch_mul_mat_q<type,  80, 8>(args, stream); break;
          case  88: launch_mul_mat_q<type,  88, 8>(args, stream); break;
          case  96: launch_mul_mat_q<type,  96, 8>(args, stream); break;
          case 104: launch_mul_mat_q<type, 104, 8>(args, stream); break;
          case 112: launch_mul_mat_q<type, 112, 8>(args, stream); break;
          case 120: launch_mul_mat_q<type, 120, 8>(args, stream); break;
          case 128: launch_mul_mat_q<type, 128, 8>(args, stream); break;
          default:
              GGML_ASSERT(false);
              break;
      }
  }
  // Expands to an explicit instantiation of mul_mat_q_case for one quantization type.
  // Prefixed with `extern` (as below) it only declares that the instantiation exists elsewhere.
  // The macro body must end on the line below: a trailing backslash there would splice the
  // following line into the macro.
  #define DECL_MMQ_CASE(type) \
      template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream)

  extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
  extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
  extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
  extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
  extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
  extern DECL_MMQ_CASE(GGML_TYPE_Q5_K);
  extern DECL_MMQ_CASE(GGML_TYPE_Q6_K);
  1012. // -------------------------------------------------------------------------------------------------------------------------
  1013. void ggml_cuda_op_mul_mat_q(
  1014. ggml_backend_cuda_context & ctx,
  1015. const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
  1016. const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
  1017. const int64_t src1_padded_row_size, cudaStream_t stream);
  1018. bool ggml_cuda_supports_mmq(enum ggml_type type);