convert.cpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544
  1. #include "convert.hpp"
  2. #include "dequantize.hpp"
  3. #include "presets.hpp"
  4. template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  5. static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
  6. const sycl::nd_item<3> &item_ct1) {
  7. const int i = 2 * (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
  8. item_ct1.get_local_id(2));
  9. if (i >= k) {
  10. return;
  11. }
  12. const int ib = i/qk; // block index
  13. const int iqs = (i%qk)/qr; // quant index
  14. const int iybs = i - i%qk; // y block start index
  15. const int y_offset = qr == 1 ? 1 : qk/2;
  16. // dequantize
  17. dfloat2 v;
  18. dequantize_kernel(vx, ib, iqs, v);
  19. y[iybs + iqs + 0] = v.x();
  20. y[iybs + iqs + y_offset] = v.y();
  21. }
  22. template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  23. static void dequantize_block_sycl(const void *__restrict__ vx,
  24. dst_t *__restrict__ y, const int k,
  25. dpct::queue_ptr stream) {
  26. const int num_blocks = (k + 2*SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2*SYCL_DEQUANTIZE_BLOCK_SIZE);
  27. {
  28. dpct::has_capability_or_fail(stream->get_device(),
  29. {sycl::aspect::fp16});
  30. stream->parallel_for(
  31. sycl::nd_range<3>(
  32. sycl::range<3>(1, 1, num_blocks) *
  33. sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
  34. sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
  35. [=](sycl::nd_item<3> item_ct1) {
  36. dequantize_block<qk, qr, dequantize_kernel>(vx, y, k, item_ct1);
  37. });
  38. }
  39. }
  40. template <typename dst_t>
  41. static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k,
  42. dpct::queue_ptr stream) {
  43. const int nb = k / QK_K;
  44. #if QK_K == 256
  45. {
  46. dpct::has_capability_or_fail(stream->get_device(),
  47. {sycl::aspect::fp16});
  48. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  49. sycl::range<3>(1, 1, 64),
  50. sycl::range<3>(1, 1, 64)),
  51. [=](sycl::nd_item<3> item_ct1) {
  52. dequantize_block_q2_K(vx, y, item_ct1);
  53. });
  54. }
  55. #else
  56. {
  57. dpct::has_capability_or_fail(stream->get_device(),
  58. {sycl::aspect::fp16});
  59. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  60. sycl::range<3>(1, 1, 32),
  61. sycl::range<3>(1, 1, 32)),
  62. [=](sycl::nd_item<3> item_ct1) {
  63. dequantize_block_q2_K(vx, y, item_ct1);
  64. });
  65. }
  66. #endif
  67. }
  68. template <typename dst_t>
  69. static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k,
  70. dpct::queue_ptr stream) {
  71. const int nb = k / QK_K;
  72. #if QK_K == 256
  73. {
  74. dpct::has_capability_or_fail(stream->get_device(),
  75. {sycl::aspect::fp16});
  76. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  77. sycl::range<3>(1, 1, 64),
  78. sycl::range<3>(1, 1, 64)),
  79. [=](sycl::nd_item<3> item_ct1) {
  80. dequantize_block_q3_K(vx, y, item_ct1);
  81. });
  82. }
  83. #else
  84. {
  85. dpct::has_capability_or_fail(stream->get_device(),
  86. {sycl::aspect::fp16});
  87. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  88. sycl::range<3>(1, 1, 32),
  89. sycl::range<3>(1, 1, 32)),
  90. [=](sycl::nd_item<3> item_ct1) {
  91. dequantize_block_q3_K(vx, y, item_ct1);
  92. });
  93. }
  94. #endif
  95. }
  96. template <typename dst_t>
  97. static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k,
  98. dpct::queue_ptr stream) {
  99. const int nb32 = k / 32;
  100. const int nb = (k + 255) / 256;
  101. {
  102. dpct::has_capability_or_fail(stream->get_device(),
  103. {sycl::aspect::fp16});
  104. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  105. sycl::range<3>(1, 1, 32),
  106. sycl::range<3>(1, 1, 32)),
  107. [=](sycl::nd_item<3> item_ct1) {
  108. dequantize_block_q4_0(vx, y, nb32, item_ct1);
  109. });
  110. }
  111. }
  112. template <typename dst_t>
  113. static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k,
  114. dpct::queue_ptr stream) {
  115. const int nb32 = k / 32;
  116. const int nb = (k + 255) / 256;
  117. {
  118. dpct::has_capability_or_fail(stream->get_device(),
  119. {sycl::aspect::fp16});
  120. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  121. sycl::range<3>(1, 1, 32),
  122. sycl::range<3>(1, 1, 32)),
  123. [=](sycl::nd_item<3> item_ct1) {
  124. dequantize_block_q4_1(vx, y, nb32, item_ct1);
  125. });
  126. }
  127. }
  128. template <typename dst_t>
  129. static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
  130. dpct::queue_ptr stream) {
  131. const int nb = k / QK_K;
  132. {
  133. dpct::has_capability_or_fail(stream->get_device(),
  134. {sycl::aspect::fp16});
  135. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  136. sycl::range<3>(1, 1, 32),
  137. sycl::range<3>(1, 1, 32)),
  138. [=](sycl::nd_item<3> item_ct1) {
  139. dequantize_block_q4_K(vx, y, item_ct1);
  140. });
  141. }
  142. }
  143. template <typename dst_t>
  144. static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k,
  145. dpct::queue_ptr stream) {
  146. const int nb = k / QK_K;
  147. #if QK_K == 256
  148. {
  149. dpct::has_capability_or_fail(stream->get_device(),
  150. {sycl::aspect::fp16});
  151. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  152. sycl::range<3>(1, 1, 64),
  153. sycl::range<3>(1, 1, 64)),
  154. [=](sycl::nd_item<3> item_ct1) {
  155. dequantize_block_q5_K(vx, y, item_ct1);
  156. });
  157. }
  158. #else
  159. {
  160. dpct::has_capability_or_fail(stream->get_device(),
  161. {sycl::aspect::fp16});
  162. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  163. sycl::range<3>(1, 1, 32),
  164. sycl::range<3>(1, 1, 32)),
  165. [=](sycl::nd_item<3> item_ct1) {
  166. dequantize_block_q5_K(vx, y, item_ct1);
  167. });
  168. }
  169. #endif
  170. }
  171. template <typename dst_t>
  172. static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k,
  173. dpct::queue_ptr stream) {
  174. const int nb = k / QK_K;
  175. #if QK_K == 256
  176. {
  177. dpct::has_capability_or_fail(stream->get_device(),
  178. {sycl::aspect::fp16});
  179. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  180. sycl::range<3>(1, 1, 64),
  181. sycl::range<3>(1, 1, 64)),
  182. [=](sycl::nd_item<3> item_ct1) {
  183. dequantize_block_q6_K(vx, y, item_ct1);
  184. });
  185. }
  186. #else
  187. {
  188. dpct::has_capability_or_fail(stream->get_device(),
  189. {sycl::aspect::fp16});
  190. stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  191. sycl::range<3>(1, 1, 32),
  192. sycl::range<3>(1, 1, 32)),
  193. [=](sycl::nd_item<3> item_ct1) {
  194. dequantize_block_q6_K(vx, y, item_ct1);
  195. });
  196. }
  197. #endif
  198. }
  199. template <typename dst_t>
  200. static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int k,
  201. dpct::queue_ptr stream) {
  202. const int nb = k / QK_K;
  203. {
  204. dpct::has_capability_or_fail(stream->get_device(),
  205. {sycl::aspect::fp16});
  206. stream->submit([&](sycl::handler &cgh) {
  207. cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  208. sycl::range<3>(1, 1, 32),
  209. sycl::range<3>(1, 1, 32)),
  210. [=](sycl::nd_item<3> item_ct1) {
  211. dequantize_block_iq1_s(
  212. vx, y, item_ct1, iq1s_grid_gpu
  213. );
  214. });
  215. });
  216. }
  217. }
  218. template <typename dst_t>
  219. static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int k,
  220. dpct::queue_ptr stream) {
  221. const int nb = k / QK_K;
  222. {
  223. dpct::has_capability_or_fail(stream->get_device(),
  224. {sycl::aspect::fp16});
  225. stream->submit([&](sycl::handler &cgh) {
  226. cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  227. sycl::range<3>(1, 1, 32),
  228. sycl::range<3>(1, 1, 32)),
  229. [=](sycl::nd_item<3> item_ct1) {
  230. dequantize_block_iq1_m(
  231. vx, y, item_ct1, iq1s_grid_gpu
  232. );
  233. });
  234. });
  235. }
  236. }
  237. template <typename dst_t>
  238. static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k,
  239. dpct::queue_ptr stream) {
  240. const int nb = k / QK_K;
  241. {
  242. dpct::has_capability_or_fail(stream->get_device(),
  243. {sycl::aspect::fp16});
  244. stream->submit([&](sycl::handler &cgh) {
  245. cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  246. sycl::range<3>(1, 1, 32),
  247. sycl::range<3>(1, 1, 32)),
  248. [=](sycl::nd_item<3> item_ct1) {
  249. dequantize_block_iq2_xxs(
  250. vx, y, item_ct1, iq2xxs_grid,
  251. ksigns_iq2xs, kmask_iq2xs);
  252. });
  253. });
  254. }
  255. }
  256. template <typename dst_t>
  257. static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k,
  258. dpct::queue_ptr stream) {
  259. const int nb = k / QK_K;
  260. {
  261. dpct::has_capability_or_fail(stream->get_device(),
  262. {sycl::aspect::fp16});
  263. stream->submit([&](sycl::handler &cgh) {
  264. cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  265. sycl::range<3>(1, 1, 32),
  266. sycl::range<3>(1, 1, 32)),
  267. [=](sycl::nd_item<3> item_ct1) {
  268. dequantize_block_iq2_xs(
  269. vx, y, item_ct1, iq2xs_grid,
  270. ksigns_iq2xs, kmask_iq2xs);
  271. });
  272. });
  273. }
  274. }
  275. template <typename dst_t>
  276. static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int k,
  277. dpct::queue_ptr stream) {
  278. const int nb = k / QK_K;
  279. {
  280. dpct::has_capability_or_fail(stream->get_device(),
  281. {sycl::aspect::fp16});
  282. stream->submit([&](sycl::handler &cgh) {
  283. cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  284. sycl::range<3>(1, 1, 32),
  285. sycl::range<3>(1, 1, 32)),
  286. [=](sycl::nd_item<3> item_ct1) {
  287. dequantize_block_iq2_s(vx, y, item_ct1);
  288. });
  289. });
  290. }
  291. }
  292. template <typename dst_t>
  293. static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k,
  294. dpct::queue_ptr stream) {
  295. const int nb = k / QK_K;
  296. {
  297. dpct::has_capability_or_fail(stream->get_device(),
  298. {sycl::aspect::fp16});
  299. stream->submit([&](sycl::handler &cgh) {
  300. cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  301. sycl::range<3>(1, 1, 32),
  302. sycl::range<3>(1, 1, 32)),
  303. [=](sycl::nd_item<3> item_ct1) {
  304. dequantize_block_iq3_xxs(
  305. vx, y, item_ct1, iq3xxs_grid,
  306. ksigns_iq2xs, kmask_iq2xs);
  307. });
  308. });
  309. }
  310. }
  311. template <typename dst_t>
  312. static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int k,
  313. dpct::queue_ptr stream) {
  314. const int nb = k / QK_K;
  315. {
  316. dpct::has_capability_or_fail(stream->get_device(),
  317. {sycl::aspect::fp16});
  318. stream->submit([&](sycl::handler &cgh) {
  319. cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  320. sycl::range<3>(1, 1, 32),
  321. sycl::range<3>(1, 1, 32)),
  322. [=](sycl::nd_item<3> item_ct1) {
  323. dequantize_block_iq3_s(
  324. vx, y, item_ct1, kmask_iq2xs, iq3s_grid);
  325. });
  326. });
  327. }
  328. }
  329. template <typename dst_t>
  330. static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int k,
  331. dpct::queue_ptr stream) {
  332. const int nb = (k + QK_K - 1) / QK_K;
  333. #if QK_K == 64
  334. dequantize_row_iq4_nl_sycl(vx, y, k, stream);
  335. #else
  336. {
  337. dpct::has_capability_or_fail(stream->get_device(),
  338. {sycl::aspect::fp16});
  339. stream->submit([&](sycl::handler &cgh) {
  340. cgh.parallel_for(
  341. sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  342. sycl::range<3>(1, 1, 32),
  343. sycl::range<3>(1, 1, 32)),
  344. [=](sycl::nd_item<3> item_ct1) {
  345. dequantize_block_iq4_xs(vx, y, item_ct1);
  346. });
  347. });
  348. }
  349. #endif
  350. }
  351. template <typename dst_t>
  352. static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int k,
  353. dpct::queue_ptr stream) {
  354. const int nb = (k + QK_K - 1) / QK_K;
  355. {
  356. dpct::has_capability_or_fail(stream->get_device(),
  357. {sycl::aspect::fp16});
  358. stream->submit([&](sycl::handler &cgh) {
  359. cgh.parallel_for(
  360. sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
  361. sycl::range<3>(1, 1, 32),
  362. sycl::range<3>(1, 1, 32)),
  363. [=](sycl::nd_item<3> item_ct1) {
  364. dequantize_block_iq4_nl(vx, y, item_ct1);
  365. });
  366. });
  367. }
  368. }
  369. template <typename src_t, typename dst_t>
  370. static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k,
  371. const sycl::nd_item<3> &item_ct1) {
  372. const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
  373. item_ct1.get_local_id(2);
  374. if (i >= k) {
  375. return;
  376. }
  377. const src_t * x = (src_t *) vx;
  378. y[i] = x[i];
  379. }
  380. template <typename src_t, typename dst_t>
  381. static void convert_unary_sycl(const void *__restrict__ vx,
  382. dst_t *__restrict__ y, const int k,
  383. dpct::queue_ptr stream) {
  384. const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE;
  385. {
  386. dpct::has_capability_or_fail(stream->get_device(),
  387. {sycl::aspect::fp16});
  388. stream->parallel_for(
  389. sycl::nd_range<3>(
  390. sycl::range<3>(1, 1, num_blocks) *
  391. sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
  392. sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
  393. [=](sycl::nd_item<3> item_ct1) {
  394. convert_unary<src_t>(vx, y, k, item_ct1);
  395. });
  396. }
  397. }
  398. to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) {
  399. switch (type) {
  400. case GGML_TYPE_Q4_0:
  401. return dequantize_block_sycl<QK4_0, QR4_0, dequantize_q4_0>;
  402. case GGML_TYPE_Q4_1:
  403. return dequantize_block_sycl<QK4_1, QR4_1, dequantize_q4_1>;
  404. case GGML_TYPE_Q5_0:
  405. return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
  406. case GGML_TYPE_Q5_1:
  407. return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
  408. case GGML_TYPE_Q8_0:
  409. return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
  410. case GGML_TYPE_Q2_K:
  411. return dequantize_row_q2_K_sycl;
  412. case GGML_TYPE_Q3_K:
  413. return dequantize_row_q3_K_sycl;
  414. case GGML_TYPE_Q4_K:
  415. return dequantize_row_q4_K_sycl;
  416. case GGML_TYPE_Q5_K:
  417. return dequantize_row_q5_K_sycl;
  418. case GGML_TYPE_Q6_K:
  419. return dequantize_row_q6_K_sycl;
  420. case GGML_TYPE_IQ1_S:
  421. return dequantize_row_iq1_s_sycl;
  422. case GGML_TYPE_IQ1_M:
  423. return dequantize_row_iq1_m_sycl;
  424. case GGML_TYPE_IQ2_XXS:
  425. return dequantize_row_iq2_xxs_sycl;
  426. case GGML_TYPE_IQ2_XS:
  427. return dequantize_row_iq2_xs_sycl;
  428. case GGML_TYPE_IQ2_S:
  429. return dequantize_row_iq2_s_sycl;
  430. case GGML_TYPE_IQ3_XXS:
  431. return dequantize_row_iq3_xxs_sycl;
  432. case GGML_TYPE_IQ3_S:
  433. return dequantize_row_iq3_s_sycl;
  434. case GGML_TYPE_IQ4_XS:
  435. return dequantize_row_iq4_xs_sycl;
  436. case GGML_TYPE_IQ4_NL:
  437. return dequantize_row_iq4_nl_sycl;
  438. case GGML_TYPE_F32:
  439. return convert_unary_sycl<float>;
  440. default:
  441. return nullptr;
  442. }
  443. }
  444. to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) {
  445. switch (type) {
  446. case GGML_TYPE_Q4_0:
  447. return dequantize_row_q4_0_sycl;
  448. case GGML_TYPE_Q4_1:
  449. return dequantize_row_q4_1_sycl;
  450. case GGML_TYPE_Q5_0:
  451. return dequantize_block_sycl<QK5_0, QR5_0, dequantize_q5_0>;
  452. case GGML_TYPE_Q5_1:
  453. return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>;
  454. case GGML_TYPE_Q8_0:
  455. return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>;
  456. case GGML_TYPE_Q2_K:
  457. return dequantize_row_q2_K_sycl;
  458. case GGML_TYPE_Q3_K:
  459. return dequantize_row_q3_K_sycl;
  460. case GGML_TYPE_Q4_K:
  461. return dequantize_row_q4_K_sycl;
  462. case GGML_TYPE_Q5_K:
  463. return dequantize_row_q5_K_sycl;
  464. case GGML_TYPE_Q6_K:
  465. return dequantize_row_q6_K_sycl;
  466. case GGML_TYPE_IQ1_S:
  467. return dequantize_row_iq1_s_sycl;
  468. case GGML_TYPE_IQ1_M:
  469. return dequantize_row_iq1_m_sycl;
  470. case GGML_TYPE_IQ2_XXS:
  471. return dequantize_row_iq2_xxs_sycl;
  472. case GGML_TYPE_IQ2_XS:
  473. return dequantize_row_iq2_xs_sycl;
  474. case GGML_TYPE_IQ2_S:
  475. return dequantize_row_iq2_s_sycl;
  476. case GGML_TYPE_IQ3_XXS:
  477. return dequantize_row_iq3_xxs_sycl;
  478. case GGML_TYPE_IQ3_S:
  479. return dequantize_row_iq3_s_sycl;
  480. case GGML_TYPE_IQ4_XS:
  481. return dequantize_row_iq4_xs_sycl;
  482. case GGML_TYPE_IQ4_NL:
  483. return dequantize_row_iq4_nl_sycl;
  484. case GGML_TYPE_F16:
  485. return convert_unary_sycl<sycl::half>;
  486. default:
  487. return nullptr;
  488. }
  489. }