mmvq.cu

#include "mmvq.cuh"
#include "vecdotq.cuh"

typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);

static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
    return type == GGML_TYPE_Q4_0    ? vec_dot_q4_0_q8_1 :
           type == GGML_TYPE_Q4_1    ? vec_dot_q4_1_q8_1 :
           type == GGML_TYPE_Q5_0    ? vec_dot_q5_0_q8_1 :
           type == GGML_TYPE_Q5_1    ? vec_dot_q5_1_q8_1 :
           type == GGML_TYPE_Q8_0    ? vec_dot_q8_0_q8_1 :
           type == GGML_TYPE_Q2_K    ? vec_dot_q2_K_q8_1 :
           type == GGML_TYPE_Q3_K    ? vec_dot_q3_K_q8_1 :
           type == GGML_TYPE_Q4_K    ? vec_dot_q4_K_q8_1 :
           type == GGML_TYPE_Q5_K    ? vec_dot_q5_K_q8_1 :
           type == GGML_TYPE_Q6_K    ? vec_dot_q6_K_q8_1 :
           type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
           type == GGML_TYPE_IQ2_XS  ? vec_dot_iq2_xs_q8_1 :
           type == GGML_TYPE_IQ2_S   ? vec_dot_iq2_s_q8_1 :
           type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
           type == GGML_TYPE_IQ1_S   ? vec_dot_iq1_s_q8_1 :
           type == GGML_TYPE_IQ1_M   ? vec_dot_iq1_m_q8_1 :
           type == GGML_TYPE_IQ4_NL  ? vec_dot_iq4_nl_q8_1 :
           type == GGML_TYPE_IQ4_XS  ? vec_dot_iq4_xs_q8_1 :
           type == GGML_TYPE_IQ3_S   ? vec_dot_iq3_s_q8_1 :
           nullptr;
}
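
// VDR = vec dot ratio: the number of contiguous integers of quantized data each thread
// consumes per vec_dot call for a given type (the VDR_*_MMVQ constants come from vecdotq.cuh)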
static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
    return type == GGML_TYPE_Q4_0   ? VDR_Q4_0_Q8_1_MMVQ :
           type == GGML_TYPE_Q4_1   ? VDR_Q4_1_Q8_1_MMVQ :
           type == GGML_TYPE_Q5_0   ? VDR_Q5_0_Q8_1_MMVQ :
           type == GGML_TYPE_Q5_1   ? VDR_Q5_1_Q8_1_MMVQ :
           type == GGML_TYPE_Q8_0   ? VDR_Q8_0_Q8_1_MMVQ :
           type == GGML_TYPE_Q2_K   ? VDR_Q2_K_Q8_1_MMVQ :
           type == GGML_TYPE_Q3_K   ? VDR_Q3_K_Q8_1_MMVQ :
           type == GGML_TYPE_Q4_K   ? VDR_Q4_K_Q8_1_MMVQ :
           type == GGML_TYPE_Q5_K   ? VDR_Q5_K_Q8_1_MMVQ :
           type == GGML_TYPE_Q6_K   ? VDR_Q6_K_Q8_1_MMVQ :
           type == GGML_TYPE_IQ4_NL ? VDR_Q4_K_Q8_1_MMVQ :
           1;
}
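
// both helpers above are constexpr, so each template instantiation of the kernel below bakes
// in the matching dot product function and vdr at compile time; the call through the
// vec_dot_q_cuda pointer can therefore be resolved to a direct (inlinable) call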
template <ggml_type type, int ncols_y>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {

    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
    constexpr int vdr = get_vdr_mmvq(type);

    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
    constexpr int nwarps              = 1;
    constexpr int rows_per_cuda_block = 1;
#else
    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))

    const     int tid  = WARP_SIZE*threadIdx.y + threadIdx.x;
    const     int row0 = rows_per_cuda_block*blockIdx.x;
    const     int blocks_per_row_x = ncols_x / qk;
    const     int blocks_per_col_y = nrows_y / QK8_1;
    constexpr int blocks_per_iter  = vdr * nwarps*WARP_SIZE / qi;
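
    // each thread block handles rows_per_cuda_block consecutive rows of x for all ncols_y
    // columns of y; qi/vdr consecutive threads share one x block (at different kqs offsets),
    // so the whole block consumes blocks_per_iter x blocks per loop iteration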
    // partial sum for each thread
    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};

    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx

        // x block quant index when casting the quants to int
        const int kqs = vdr * (tid % (qi/vdr));

#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
            }
        }
    }

    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
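
    // cross-warp reduction: warps 1..nwarps-1 stash their partial sums in shared memory,
    // then warp 0 accumulates them and finishes with a lane-level warp_reduce_sum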
    if (threadIdx.y > 0) {
#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
            }
        }
    }
    __syncthreads();
    if (threadIdx.y > 0) {
        return;
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
        for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
            for (int l = 0; l < nwarps-1; ++l) {
                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
            }
            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
        }

        if (threadIdx.x < rows_per_cuda_block) {
            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
        }
    }
}
template <ggml_type type>
static void mul_mat_vec_q_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

    int id = ggml_cuda_get_device();

    int64_t nwarps = 1;
    int64_t rows_per_cuda_block = 1;

    if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
        switch(ncols_y) {
            case 1:
                nwarps = 4;
                rows_per_cuda_block = 1;
                break;
            case 2:
            case 3:
            case 4:
                nwarps = 4;
                rows_per_cuda_block = 2;
                break;
            case 5:
            case 6:
            case 7:
            case 8:
                nwarps = 2;
                rows_per_cuda_block = 2;
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
    }
    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
    const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);
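
    // the launch configuration chosen here must mirror the constexpr nwarps/rows_per_cuda_block
    // picked inside mul_mat_vec_q for the same device class: the kernel derives its indexing
    // from those compile-time constants, not from blockDim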
    switch (ncols_y) {
        case 1:
            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 2:
            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 3:
            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 4:
            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 5:
            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 6:
            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 7:
            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 8:
            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
}
static void mul_mat_vec_q4_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q4_1_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_1_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q8_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q2_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q3_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q4_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q6_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xxs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq3_xxs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq1_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq1_m_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq4_nl_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq4_xs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq3_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
void ggml_cuda_op_mul_mat_vec_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    const int64_t ne00 = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne0 = dst->ne[0];

    int id = ggml_cuda_get_device();

    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
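
    // src1_ddq_i is src1 already quantized to q8_1; its rows are padded to src1_padded_row_size,
    // which is forwarded below as nrows_y and acts as the stride (in values) between y columns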
    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_1:
            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_0:
            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_1:
            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q8_0:
            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q2_K:
            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q3_K:
            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_K:
            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_K:
            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XS:
            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_S:
            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_S:
            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_M:
            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_NL:
            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_XS:
            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_S:
            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddf_i);
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
}