mmvq.cu
#include "mmvq.cuh"
#include "vecdotq.cuh"

typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
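
// mul_mat_vec_q multiplies a quantized matrix x (vx) with a q8_1-quantized matrix y (vy) that has
// up to MMVQ_MAX_BATCH_SIZE columns. Each thread block produces rows_per_cuda_block output rows:
// every thread accumulates partial dot products over x blocks, and the partial sums are then
// combined across warps through shared memory before being written to dst.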
template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
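
    // RDNA2/RDNA3 GPUs use a single warp and a single output row per thread block;
    // on other devices the warp count and rows per block depend on the y batch size (ncols_y)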
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
    constexpr int nwarps              = 1;
    constexpr int rows_per_cuda_block = 1;
#else
    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))

    const int tid  = WARP_SIZE*threadIdx.y + threadIdx.x;
    const int row0 = rows_per_cuda_block*blockIdx.x;
    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
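
    // each group of qi/vdr threads covers one x block, so the whole thread block
    // advances by blocks_per_iter x blocks per loop iteration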
    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;

    // partial sum for each thread
    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};

    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx

        // x block quant index when casting the quants to int
        const int kqs = vdr * (tid % (qi/vdr));

#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp[j][i] += vec_dot_q_cuda(
                    &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
            }
        }
    }
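
    // warps other than warp 0 stage their partial sums in shared memory;
    // after the barrier, warp 0 folds them in and finishes with a warp-level reduction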
    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
    if (threadIdx.y > 0) {
#pragma unroll
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
            }
        }
    }
    __syncthreads();
    if (threadIdx.y > 0) {
        return;
    }

    // sum up partial sums and write back result
#pragma unroll
    for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
        for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
            for (int l = 0; l < nwarps-1; ++l) {
                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
            }
            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
        }

        if (threadIdx.x < rows_per_cuda_block) {
            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
        }
    }
}
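
// host-side launcher: mirrors the kernel's nwarps/rows_per_cuda_block choices to size the grid,
// then dispatches on ncols_y so the kernel is instantiated with a compile-time column count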
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
static void mul_mat_vec_q_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    GGML_ASSERT(ncols_x % qk == 0);
    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

    int id;
    CUDA_CHECK(cudaGetDevice(&id));

    int64_t nwarps = 1;
    int64_t rows_per_cuda_block = 1;

    if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
        switch(ncols_y) {
            case 1:
                nwarps = 4;
                rows_per_cuda_block = 1;
                break;
            case 2:
            case 3:
            case 4:
                nwarps = 4;
                rows_per_cuda_block = 2;
                break;
            case 5:
            case 6:
            case 7:
            case 8:
                nwarps = 2;
                rows_per_cuda_block = 2;
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
    }
    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
    const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);
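
    // dispatch on ncols_y: passing it as a template argument keeps it a compile-time constant,
    // which the kernel needs for its fixed-size tmp array and unrolled per-column loops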
    switch (ncols_y) {
        case 1:
            mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 2:
            mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 3:
            mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 4:
            mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 5:
            mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 6:
            mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 7:
            mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 8:
            mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
                <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
}
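
// per-type wrappers: each binds one ggml quantization type to its block layout constants
// (qk, qi, block struct, vdr) and the matching vec_dot function from vecdotq.cuh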
static void mul_mat_vec_q4_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q4_1_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_1_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q8_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q2_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q3_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q4_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q6_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xxs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq3_xxs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq1_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq1_m_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq4_nl_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq4_xs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq3_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
    mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
        (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
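
// ggml backend entry point: src0 holds the quantized weights (src0_dd_i), src1_ddq_i the
// q8_1-quantized activations; dispatches to the wrapper matching src0->type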
void ggml_cuda_op_mul_mat_vec_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    const int64_t ne00 = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne0 = dst->ne[0];

    int id;
    CUDA_CHECK(cudaGetDevice(&id));

    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_1:
            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_0:
            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_1:
            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q8_0:
            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q2_K:
            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q3_K:
            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_K:
            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_K:
            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XS:
            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_S:
            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_S:
            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_M:
            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_NL:
            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_XS:
            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_S:
            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddf_i);
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
}