//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#include "mmq.hpp"
#include "vecdotq.hpp"

typedef void (*allocate_tiles_sycl_t)(
    int** x_ql,
    sycl::half2** x_dm,
    int** x_qh,
    int** x_sc);

typedef void (*load_tiles_sycl_t)(
    const void* __restrict__ vx,
    int* __restrict__ x_ql,
    sycl::half2* __restrict__ x_dm,
    int* __restrict__ x_qh,
    int* __restrict__ x_sc,
    const int& i_offset,
    const int& i_max,
    const int& k,
    const int& blocks_per_row);

typedef float (*vec_dot_q_mul_mat_sycl_t)(
    const int* __restrict__ x_ql,
    const sycl::half2* __restrict__ x_dm,
    const int* __restrict__ x_qh,
    const int* __restrict__ x_sc,
    const int* __restrict__ y_qs,
    const sycl::half2* __restrict__ y_ms,
    const int& i,
    const int& j,
    const int& k);
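
// The three function-pointer types above are the per-quantization hooks that
// mul_mat_q composes: allocate_tiles_* binds the local-memory tile buffers,
// load_tiles_* dequantizes/copies one tile of x into them, and
// vec_dot_*_mul_mat computes one output element from the x and y tiles.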

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q4_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_qs_q4_0, float *tile_x_d_q4_0) {
    (void)x_qh; (void)x_sc;
    *x_ql = tile_x_qs_q4_0;
    *x_dm = (sycl::half2 *)tile_x_d_q4_0;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_0(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI4_0;
    const int kqsx = k % QI4_0;
    const block_q4_0 * bx0 = (const block_q4_0 *) vx;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
        // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
        int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
    }
}
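
// Note on the tile layout (editor's reading of the code, not a documented
// contract): rows of x_ql use a stride of WARP_SIZE + 1 ints rather than
// WARP_SIZE. The extra element per row is the usual padding trick so that
// threads of a sub-group stepping down a tile column do not all hit the same
// local-memory bank.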

static __dpct_inline__ float vec_dot_q4_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const float * x_dmf = (const float *) x_dm;
    int u[2*VDR_Q4_0_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
    }
    return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
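
// Worked example of the kyqs remapping (editor's note, assuming QI8_1 == 8
// and QI4_0 == 4 as in the ggml quantization headers):
// kyqs = k % 4 + 8 * (k / 4), so k = 0..3 -> 0..3 and k = 4..7 -> 8..11.
// Each thread thus reads the first half of a q8_1 block directly and picks
// up the second half via the "+ QI4_0" offset in the loop above, matching
// how a q4_0 int packs a low and a high nibble whose partner y values sit
// QI4_0 ints (16 quant values) apart.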

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q4_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_qs_q4_1, sycl::half2 *tile_x_dm_q4_1) {
    (void)x_qh; (void)x_sc;
    *x_ql = tile_x_qs_q4_1;
    *x_dm = tile_x_dm_q4_1;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_1(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI4_1;
    const int kqsx = k % QI4_1;
    const block_q4_1 * bx0 = (const block_q4_1 *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
        int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
    }
}

static __dpct_inline__ float vec_dot_q4_1_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    int u[2*VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
    }
    return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
         y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q5_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql_q5_0, float *tile_x_d_q5_0) {
    (void)x_qh; (void)x_sc;
    *x_ql = tile_x_ql_q5_0;
    *x_dm = (sycl::half2 *)tile_x_d_q5_0;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_0(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI5_0;
    const int kqsx = k % QI5_0;
    const block_q5_0 * bx0 = (const block_q5_0 *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
        const int ql = get_int_from_uint8(bxi->qs, kqsx);
        const int qh = get_int_from_uint8(bxi->qh, 0) >> (4 * (k % QI5_0));
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
        qs0 = dpct::vectorized_binary<sycl::char4>(
            qs0, 0x10101010, dpct::sub_sat()); // subtract 16
        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
        qs1 = dpct::vectorized_binary<sycl::char4>(
            qs1, 0x10101010, dpct::sub_sat()); // subtract 16
        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
    const int kbxd = k % blocks_per_tile_x_row;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
        int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = bxi->d;
    }
}
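
// Worked example of the q5_0 high-bit splice above (editor's illustration):
// for byte 0 of qs0, bit 0 of qh is the fifth bit of quant value 0, and
// (qh << 4) & 0x00000010 drops it into bit 4, directly above the low nibble
// at bits 0..3. Bits 1, 2 and 3 of qh land at bits 12, 20 and 28 the same
// way, so each byte of qs0 holds one complete 5-bit value before sub_sat()
// removes the 16 offset.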

static __dpct_inline__ float vec_dot_q5_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;
    int u[2*VDR_Q5_0_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
    }
    return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q5_1(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql_q5_1, sycl::half2 *tile_x_dm_q5_1) {
    (void)x_qh; (void)x_sc;
    *x_ql = tile_x_ql_q5_1;
    *x_dm = tile_x_dm_q5_1;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_1(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI5_1;
    const int kqsx = k % QI5_1;
    const block_q5_1 * bx0 = (const block_q5_1 *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int qh = get_int_from_uint8_aligned(bxi->qh, 0) >> (4 * (k % QI5_1));
        int qs0 = (ql >> 0) & 0x0F0F0F0F;
        qs0    |= (qh <<  4) & 0x00000010; // 0 ->  4
        qs0    |= (qh << 11) & 0x00001000; // 1 -> 12
        qs0    |= (qh << 18) & 0x00100000; // 2 -> 20
        qs0    |= (qh << 25) & 0x10000000; // 3 -> 28
        x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0;
        int qs1 = (ql >> 4) & 0x0F0F0F0F;
        qs1    |= (qh >> 12) & 0x00000010; // 16 ->  4
        qs1    |= (qh >>  5) & 0x00001000; // 17 -> 12
        qs1    |= (qh <<  2) & 0x00100000; // 18 -> 20
        qs1    |= (qh <<  9) & 0x10000000; // 19 -> 28
        x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
        int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
    }
}

static __dpct_inline__ float vec_dot_q5_1_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;
    const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
    const int index_bx = i * (WARP_SIZE/QI5_1) + i/QI5_1 + k/QI5_1;
    int u[2*VDR_Q5_1_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
        u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
        u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
    }
    return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q8_0(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_qs_q8_0, float *tile_x_d_q8_0) {
    (void)x_qh; (void)x_sc;
    *x_ql = tile_x_qs_q8_0;
    *x_dm = (sycl::half2 *)tile_x_d_q8_0;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q8_0(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh; (void)x_sc;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI8_0;
    const int kqsx = k % QI8_0;
    float * x_dmf = (float *) x_dm;
    const block_q8_0 * bx0 = (const block_q8_0 *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
        int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
    }
}

static __dpct_inline__ float vec_dot_q8_0_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh; (void)x_sc;
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;
    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
        (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
}

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q2_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql_q2_K, sycl::half2 *tile_x_dm_q2_K,
                    int *tile_x_sc_q2_K) {
    (void)x_qh;
    *x_ql = tile_x_ql_q2_K;
    *x_dm = tile_x_dm_q2_K;
    *x_sc = tile_x_sc_q2_K;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q2_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI2_K;
    const int kqsx = k % QI2_K;
    const block_q2_K * bx0 = (const block_q2_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
    const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
        int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
    }
}

#define VDR_Q2_K_Q8_1_MMQ 2

// contiguous u/y values
static __dpct_inline__ float
vec_dot_q2_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
                           const uint8_t *__restrict__ scales,
                           const sycl::half2 &dm2, const float &d8) {
    int sumi_d = 0;
    int sumi_m = 0;
#pragma unroll
    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
        int sumi_d_sc = 0;
        const int sc = scales[i0 / (QI8_1/2)];
        // fill int with 4x m
        int m = sc >> 4;
        m |= m <<  8;
        m |= m << 16;
#pragma unroll
        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_d_sc = dpct::dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
            sumi_m    = dpct::dp4a(m, u[i],
                                   sumi_m); // multiply sum of q8_1 values with m
        }
        sumi_d += sumi_d_sc * (sc & 0xF);
    }
    const sycl::float2 dm2f =
        dm2.convert<float, sycl::rounding_mode::automatic>();
    return d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m);
}
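
// Editor's note on the final line: per the ggml q2_K layout, dm2 packs the
// super-block scale d and minimum m as a half2. sumi_d carries the
// scale-weighted dot products and sumi_m the m-weighted sums of q8_1 values,
// so d8 * (dm2f.x() * sumi_d - dm2f.y() * sumi_m) applies the scale and
// removes the minimum's contribution in one step.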

static __dpct_inline__ float vec_dot_q2_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh;
    const int kbx = k / QI2_K;
    const int ky  = (k % QI2_K) * QR2_K;
    const float * y_df = (const float *) y_ds;
    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
    const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
#pragma unroll
    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
    }
    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
    const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
}

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q3_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql_q3_K, sycl::half2 *tile_x_dm_q3_K,
                    int *tile_x_qh_q3_K, int *tile_x_sc_q3_K) {
    *x_ql = tile_x_ql_q3_K;
    *x_dm = tile_x_dm_q3_K;
    *x_qh = tile_x_qh_q3_K;
    *x_sc = tile_x_sc_q3_K;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q3_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI3_K;
    const int kqsx = k % QI3_K;
    const block_q3_K * bx0 = (const block_q3_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
    const int kbxd = k % blocks_per_tile_x_row;
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
        int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
        const int ksc = k % (QI3_K/4);
        const int ksc_low   = ksc % (QI3_K/8);
        const int shift_low = 4 * (ksc / (QI3_K/8));
        const int sc_low    = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
        const int ksc_high   = QI3_K/8;
        const int shift_high = 2 * ksc;
        const int sc_high    = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
        const int sc = dpct::vectorized_binary<sycl::char4>(
            sc_low | sc_high, 0x20202020, dpct::sub_sat());
        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
    }
}
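
// Editor's note: the last loop above reassembles the packed 6-bit q3_K
// scales. sc_low takes 4 bits per scale from the first 8 scale bytes,
// sc_high takes the matching 2 bits from the last 4 bytes and shifts them to
// bits 4..5, and the vectorized sub_sat() subtracts the 32 offset from all
// four bytes at once, leaving signed scales in x_sc.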

#define VDR_Q3_K_Q8_1_MMQ 2

// contiguous u/y values
static __dpct_inline__ float
vec_dot_q3_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
                           const int8_t *__restrict__ scales, const float &d3,
                           const float &d8) {
    int sumi = 0;
#pragma unroll
    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
        int sumi_sc = 0;
        for (int i = i0; i < i0 + QI8_1/2; ++i) {
            sumi_sc = dpct::dp4a(v[i], u[i], sumi_sc); // SIMD dot product
        }
        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
    }
    return d3*d8 * sumi;
}

static __dpct_inline__ float vec_dot_q3_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    const int kbx = k / QI3_K;
    const int ky  = (k % QI3_K) * QR3_K;
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;
    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
#pragma unroll
    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
        const int shift = 2 * ((ky % 32) / 8);
        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
        const int vh  = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
        const int vlh = (vh << 2) & 0x04040404;
        v[l] = dpct::vectorized_binary<sycl::char4>(vll, vlh, dpct::sub_sat());
    }
    const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
}

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q4_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql_q4_K, sycl::half2 *tile_x_dm_q4_K,
                    int *tile_x_sc_q4_K) {
    (void)x_qh;
    *x_ql = tile_x_ql_q4_K;
    *x_dm = tile_x_dm_q4_K;
    *x_sc = tile_x_sc_q4_K;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q4_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI4_K; // == 0 if QK_K == 256
    const int kqsx = k % QI4_K; // == k if QK_K == 256
    const block_q4_K * bx0 = (const block_q4_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
        int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
#else
        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
#endif
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
        const int * scales = (const int *) bxi->scales;
        const int ksc = k % (WARP_SIZE/8);
        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030;                       // upper 2 bits
        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}
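
// Editor's note on the scale repacking above: q4_K stores 8 six-bit scales
// and 8 six-bit mins in 12 bytes. The two lines producing scales8 gather,
// for a given ksc, four scales (or four mins) into one int with the layout
// sc0..sc3, sc4..sc7, m0..m3, m4..m7, so that vec_dot_q4_K_q8_1_mul_mat can
// pass sc and sc+8 as contiguous uint8_t arrays.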

#define VDR_Q4_K_Q8_1_MMQ 8

// contiguous u/y values
static __dpct_inline__ float vec_dot_q4_K_q8_1_impl_mmq(
    const int *__restrict__ v, const int *__restrict__ u,
    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;
#pragma unroll
    for (int i = 0; i < QR4_K*VDR_Q4_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;
#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = dpct::dp4a((v[j] >> (4 * i)) & 0x0F0F0F0F,
                                u[i * QI8_1 + j], sumi_d); // SIMD dot product
        }
        const sycl::float2 ds8f =
            ds8[i].convert<float, sycl::rounding_mode::automatic>();
        sumf_d += ds8f.x() * (sc[i] * sumi_d);
        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q4_K min val
    }
    const sycl::float2 dm4f =
        dm4.convert<float, sycl::rounding_mode::automatic>();
    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
}

static __dpct_inline__ float vec_dot_q4_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh;
    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
    const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8,
                                      x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
}

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q5_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql_q5_K, sycl::half2 *tile_x_dm_q5_K,
                    int *tile_x_sc_q5_K) {
    (void)x_qh;
    *x_ql = tile_x_ql_q5_K;
    *x_dm = tile_x_dm_q5_K;
    *x_sc = tile_x_sc_q5_K;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q5_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI5_K; // == 0 if QK_K == 256
    const int kqsx = k % QI5_K; // == k if QK_K == 256
    const block_q5_K * bx0 = (const block_q5_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR5_K*kqsx;
        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
        int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
#if QK_K == 256
        x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
#endif
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
        const int * scales = (const int *) bxi->scales;
        const int ksc = k % (WARP_SIZE/8);
        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
        scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030;                       // upper 2 bits
        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
    }
}

#define VDR_Q5_K_Q8_1_MMQ 8

// contiguous u/y values
static __dpct_inline__ float vec_dot_q5_K_q8_1_impl_mmq(
    const int *__restrict__ v, const int *__restrict__ u,
    const uint8_t *__restrict__ sc, const uint8_t *__restrict__ m,
    const sycl::half2 &dm4, const sycl::half2 *__restrict__ ds8) {
    float sumf_d = 0.0f;
    float sumf_m = 0.0f;
#pragma unroll
    for (int i = 0; i < QR5_K*VDR_Q5_K_Q8_1_MMQ/QI8_1; ++i) {
        int sumi_d = 0;
#pragma unroll
        for (int j = 0; j < QI8_1; ++j) {
            sumi_d = dpct::dp4a(v[i * QI8_1 + j], u[i * QI8_1 + j],
                                sumi_d); // SIMD dot product
        }
        const sycl::float2 ds8f =
            ds8[i].convert<float, sycl::rounding_mode::automatic>();
        sumf_d += ds8f.x() * (sc[i] * sumi_d);
        sumf_m += ds8f.y() * m[i]; // sum of q8_1 block * q5_K min val
    }
    const sycl::float2 dm4f =
        dm4.convert<float, sycl::rounding_mode::automatic>();
    return dm4f.x() * sumf_d - dm4f.y() * sumf_m;
}

static __dpct_inline__ float vec_dot_q5_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh;
    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
    const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
    const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
    return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8,
                                      x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
}

template <int mmq_y>
static __dpct_inline__ void
allocate_tiles_q6_K(int **x_ql, sycl::half2 **x_dm, int **x_qh, int **x_sc,
                    int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_sc) {
    (void)x_qh;
    *x_ql = tile_x_ql;
    *x_dm = tile_x_dm;
    *x_sc = tile_x_sc;
}

template <int mmq_y, int nwarps, bool need_check>
static __dpct_inline__ void
load_tiles_q6_K(const void *__restrict__ vx, int *__restrict__ x_ql,
                sycl::half2 *__restrict__ x_dm, int *__restrict__ x_qh,
                int *__restrict__ x_sc, const int &i_offset, const int &i_max,
                const int &k, const int &blocks_per_row) {
    (void)x_qh;
    GGML_SYCL_ASSUME(i_offset >= 0);
    GGML_SYCL_ASSUME(i_offset < nwarps);
    GGML_SYCL_ASSUME(k >= 0);
    GGML_SYCL_ASSUME(k < WARP_SIZE);
    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
    const int kqsx = k % QI6_K; // == k if QK_K == 256
    const block_q6_K * bx0 = (const block_q6_K *) vx;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
        int i = i0 + i_offset;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
        const int ky = QR6_K*kqsx;
        const int ql = get_int_from_uint8(bxi->ql, kqsx);
        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
        x_ql[i * (2 * WARP_SIZE + 1) + kq0] =
            dpct::vectorized_binary<sycl::char4>(ql0 | qh0, 0x20202020,
                                                 dpct::sub_sat());
        x_ql[i * (2 * WARP_SIZE + 1) + kq1] =
            dpct::vectorized_binary<sycl::char4>(ql1 | qh1, 0x20202020,
                                                 dpct::sub_sat());
    }
    const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
    float * x_dmf = (float *) x_dm;
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
        int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
    }
#pragma unroll
    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
        int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
        if (need_check) {
            i = sycl::min(i, i_max);
        }
        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
    }
}

#define VDR_Q6_K_Q8_1_MMQ 8

// contiguous u/y values
static __dpct_inline__ float
vec_dot_q6_K_q8_1_impl_mmq(const int *__restrict__ v, const int *__restrict__ u,
                           const int8_t *__restrict__ sc, const float &d6,
                           const float *__restrict__ d8) {
    float sumf_d = 0.0f;
#pragma unroll
    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
        sycl::int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
#pragma unroll
        for (int i = i0; i < i0 + 2; ++i) {
            sumi_d.x() = dpct::dp4a(v[2 * i + 0], u[2 * i + 0],
                                    sumi_d.x()); // SIMD dot product
            sumi_d.x() = dpct::dp4a(v[2 * i + 1], u[2 * i + 1],
                                    sumi_d.x()); // SIMD dot product
            sumi_d.y() = dpct::dp4a(v[2 * i + 4], u[2 * i + 4],
                                    sumi_d.y()); // SIMD dot product
            sumi_d.y() = dpct::dp4a(v[2 * i + 5], u[2 * i + 5],
                                    sumi_d.y()); // SIMD dot product
        }
        sumf_d += d8[i0 / 4] *
                  (sc[i0 / 2 + 0] * sumi_d.x() + sc[i0 / 2 + 1] * sumi_d.y());
    }
    return d6 * sumf_d;
}
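
// Editor's reading of the accumulator layout above: the sycl::int2 pairs two
// adjacent q6_K sub-blocks that share one q8_1 scale d8[i0/4]; .x() and .y()
// each gather four dp4a products before being weighted by their own int8
// scale, sc[i0/2 + 0] and sc[i0/2 + 1] respectively.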

static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat(
    const int *__restrict__ x_ql, const sycl::half2 *__restrict__ x_dm,
    const int *__restrict__ x_qh, const int *__restrict__ x_sc,
    const int *__restrict__ y_qs, const sycl::half2 *__restrict__ y_ds,
    const int &i, const int &j, const int &k) {
    (void)x_qh;
    const float * x_dmf = (const float *) x_dm;
    const float * y_df  = (const float *) y_ds;
    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
    const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
    const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
}

template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x,
          int mmq_y, int nwarps, load_tiles_sycl_t load_tiles, int vdr,
          vec_dot_q_mul_mat_sycl_t vec_dot>
/*
DPCT1110:8: The total declared local variable size in device function mul_mat_q
exceeds 128 bytes and may cause high register pressure. Consult with your
hardware vendor to find the total register size available and adjust the code,
or use smaller sub-group size to avoid high register pressure.
*/
static __dpct_inline__ void
mul_mat_q(const void *__restrict__ vx, const void *__restrict__ vy,
          float *__restrict__ dst, const int ncols_x, const int nrows_x,
          const int ncols_y, const int nrows_y, const int nrows_dst,
          int *tile_x_ql, sycl::half2 *tile_x_dm, int *tile_x_qh,
          int *tile_x_sc, const sycl::nd_item<3> &item_ct1, int *tile_y_qs,
          sycl::half2 *tile_y_ds) {
    const block_q_t  * x = (const block_q_t  *) vx;
    const block_q8_1 * y = (const block_q8_1 *) vy;
    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    const int blocks_per_warp  = WARP_SIZE / qi;
    const int & ncols_dst = ncols_y;
    const int row_dst_0 = item_ct1.get_group(2) * mmq_y;
    const int & row_x_0 = row_dst_0;
    const int col_dst_0 = item_ct1.get_group(1) * mmq_x;
    const int & col_y_0 = col_dst_0;
    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
    for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
        load_tiles(x + row_x_0 * blocks_per_row_x + ib0, tile_x_ql, tile_x_dm,
                   tile_x_qh, tile_x_sc, item_ct1.get_local_id(1),
                   nrows_x - row_x_0 - 1, item_ct1.get_local_id(2),
                   blocks_per_row_x);
#pragma unroll
        for (int ir = 0; ir < qr; ++ir) {
            const int kqs = ir * WARP_SIZE + item_ct1.get_local_id(2);
            const int kbxd = kqs / QI8_1;
#pragma unroll
            for (int i = 0; i < mmq_x; i += nwarps) {
                const int col_y_eff = dpct::min(
                    (unsigned int)(col_y_0 + item_ct1.get_local_id(1) + i),
                    ncols_y - 1); // to prevent out-of-bounds memory accesses
                const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
                const int index_y = (item_ct1.get_local_id(1) + i) * WARP_SIZE +
                                    kqs % WARP_SIZE;
                tile_y_qs[index_y] = get_int_from_int8_aligned(
                    by0->qs, item_ct1.get_local_id(2) % QI8_1);
            }
#pragma unroll
            for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
                const int ids =
                    (ids0 + item_ct1.get_local_id(1) * QI8_1 +
                     item_ct1.get_local_id(2) / (WARP_SIZE / QI8_1)) %
                    mmq_x;
                const int kby = item_ct1.get_local_id(2) % (WARP_SIZE / QI8_1);
                const int col_y_eff = sycl::min(col_y_0 + ids, ncols_y - 1);
                // if the sum is not needed it's faster to transform the scale to f32 ahead of time
                const sycl::half2 *dsi_src =
                    &y[col_y_eff * blocks_per_col_y + ib0 * (qk / QK8_1) +
                       ir * (WARP_SIZE / QI8_1) + kby]
                         .ds;
                sycl::half2 *dsi_dst =
                    &tile_y_ds[ids * (WARP_SIZE / QI8_1) + kby];
                if (need_sum) {
                    *dsi_dst = *dsi_src;
                } else {
                    float * dfi_dst = (float *) dsi_dst;
                    *dfi_dst = (*dsi_src)[0];
                }
            }
            /*
            DPCT1118:9: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:56: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            item_ct1.barrier();
            // #pragma unroll // unrolling this loop causes too much register pressure
            for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
#pragma unroll
                for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
                    for (int i = 0; i < mmq_y; i += WARP_SIZE) {
                        sum[i / WARP_SIZE][j / nwarps] += vec_dot(
                            tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
                            tile_y_qs, tile_y_ds, item_ct1.get_local_id(2) + i,
                            item_ct1.get_local_id(1) + j, k);
                    }
                }
            }
            /*
            DPCT1118:10: SYCL group functions and algorithms must be encountered
            in converged control flow. You may need to adjust the code.
            */
            /*
            DPCT1065:57: Consider replacing sycl::nd_item::barrier() with
            sycl::nd_item::barrier(sycl::access::fence_space::local_space) for
            better performance if there is no access to global memory.
            */
            item_ct1.barrier();
        }
    }
#pragma unroll
    for (int j = 0; j < mmq_x; j += nwarps) {
        const int col_dst = col_dst_0 + j + item_ct1.get_local_id(1);
        if (col_dst >= ncols_dst) {
            return;
        }
#pragma unroll
        for (int i = 0; i < mmq_y; i += WARP_SIZE) {
            const int row_dst = row_dst_0 + item_ct1.get_local_id(2) + i;
            if (row_dst >= nrows_dst) {
                continue;
            }
            dst[col_dst*nrows_dst + row_dst] = sum[i/WARP_SIZE][j/nwarps];
        }
    }
}
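
// Editor's summary of the loop structure in mul_mat_q: for each strip of
// blocks along ncols_x, one x tile is loaded, then for each of the qr
// quant-ratio passes a y tile is staged (values plus scales), a barrier
// makes the tiles visible, the vec_dot hook accumulates into the per-thread
// sum registers, and a second barrier protects the tiles before they are
// overwritten on the next pass. Only at the very end are the sums written
// out, with bounds checks on both destination axes.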

#define MMQ_X_Q4_0_RDNA2 64
#define MMQ_Y_Q4_0_RDNA2 128
#define NWARPS_Q4_0_RDNA2 8
#define MMQ_X_Q4_0_RDNA1 64
#define MMQ_Y_Q4_0_RDNA1 64
#define NWARPS_Q4_0_RDNA1 8
#if defined(SYCL_USE_XMX)
#define MMQ_X_Q4_0_AMPERE 4
#define MMQ_Y_Q4_0_AMPERE 32
#define NWARPS_Q4_0_AMPERE 4
#else
#define MMQ_X_Q4_0_AMPERE 64
#define MMQ_Y_Q4_0_AMPERE 128
#define NWARPS_Q4_0_AMPERE 4
#endif
#define MMQ_X_Q4_0_PASCAL 64
#define MMQ_Y_Q4_0_PASCAL 64
#define NWARPS_Q4_0_PASCAL 8
  1037. template <bool need_check> static void
  1038. mul_mat_q4_0(
  1039. const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
  1040. const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
  1041. const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_0, float *tile_x_d_q4_0,
  1042. int *tile_y_qs, sycl::half2 *tile_y_ds) {
  1043. int * tile_x_ql = nullptr;
  1044. sycl::half2 *tile_x_dm = nullptr;
  1045. int * tile_x_qh = nullptr;
  1046. int * tile_x_sc = nullptr;
  1047. //sycl_todo: change according to hardware
  1048. const int mmq_x = MMQ_X_Q4_0_AMPERE;
  1049. const int mmq_y = MMQ_Y_Q4_0_AMPERE;
  1050. const int nwarps = NWARPS_Q4_0_AMPERE;
  1051. allocate_tiles_q4_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
  1052. tile_x_qs_q4_0, tile_x_d_q4_0);
  1053. mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps,
  1054. load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ,
  1055. vec_dot_q4_0_q8_1_mul_mat>(
  1056. vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
  1057. tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
  1058. }
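
// The wrappers that follow repeat this pattern for the remaining formats:
// null out the generic tile pointers, bind them to the kernel's local-memory
// arguments via allocate_tiles_*<mmq_y>, then instantiate mul_mat_q with the
// format's block type, tile shape, load_tiles_* loader, VDR_* dot-product
// width, and vec_dot_* kernel. Only mul_mat_q6_K deviates slightly: it takes
// tile_x_ql/tile_x_dm/tile_x_sc directly as parameters instead of locals.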
#define  MMQ_X_Q4_1_RDNA2  64
#define  MMQ_Y_Q4_1_RDNA2  128
#define NWARPS_Q4_1_RDNA2  8
#define  MMQ_X_Q4_1_RDNA1  64
#define  MMQ_Y_Q4_1_RDNA1  64
#define NWARPS_Q4_1_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q4_1_AMPERE 4
#define  MMQ_Y_Q4_1_AMPERE 32
#define NWARPS_Q4_1_AMPERE 4
#else
#define  MMQ_X_Q4_1_AMPERE 64
#define  MMQ_Y_Q4_1_AMPERE 128
#define NWARPS_Q4_1_AMPERE 4
#endif
#define  MMQ_X_Q4_1_PASCAL 64
#define  MMQ_Y_Q4_1_PASCAL 64
#define NWARPS_Q4_1_PASCAL 8

template <bool need_check> static void
    mul_mat_q4_1(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q4_1,
    sycl::half2 *tile_x_dm_q4_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
    int   * tile_x_ql = nullptr;
    sycl::half2 *tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    // sycl_todo: change according to hardware
    const int mmq_x  =  MMQ_X_Q4_1_AMPERE;
    const int mmq_y  =  MMQ_Y_Q4_1_AMPERE;
    const int nwarps = NWARPS_Q4_1_AMPERE;

    allocate_tiles_q4_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
                               tile_x_qs_q4_1, tile_x_dm_q4_1);
    mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps,
              load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ,
              vec_dot_q4_1_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
}
#define  MMQ_X_Q5_0_RDNA2  64
#define  MMQ_Y_Q5_0_RDNA2  128
#define NWARPS_Q5_0_RDNA2  8
#define  MMQ_X_Q5_0_RDNA1  64
#define  MMQ_Y_Q5_0_RDNA1  64
#define NWARPS_Q5_0_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q5_0_AMPERE 4
#define  MMQ_Y_Q5_0_AMPERE 32
#define NWARPS_Q5_0_AMPERE 4
#else
#define  MMQ_X_Q5_0_AMPERE 128
#define  MMQ_Y_Q5_0_AMPERE 64
#define NWARPS_Q5_0_AMPERE 4
#endif
#define  MMQ_X_Q5_0_PASCAL 64
#define  MMQ_Y_Q5_0_PASCAL 64
#define NWARPS_Q5_0_PASCAL 8

template <bool need_check> static void
    mul_mat_q5_0(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_0, float *tile_x_d_q5_0,
    int *tile_y_qs, sycl::half2 *tile_y_ds) {
    int   * tile_x_ql = nullptr;
    sycl::half2 *tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    // sycl_todo: change according to hardware
    const int mmq_x  =  MMQ_X_Q5_0_AMPERE;
    const int mmq_y  =  MMQ_Y_Q5_0_AMPERE;
    const int nwarps = NWARPS_Q5_0_AMPERE;

    allocate_tiles_q5_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
                               tile_x_ql_q5_0, tile_x_d_q5_0);
    mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps,
              load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ,
              vec_dot_q5_0_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
}
#define  MMQ_X_Q5_1_RDNA2  64
#define  MMQ_Y_Q5_1_RDNA2  128
#define NWARPS_Q5_1_RDNA2  8
#define  MMQ_X_Q5_1_RDNA1  64
#define  MMQ_Y_Q5_1_RDNA1  64
#define NWARPS_Q5_1_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q5_1_AMPERE 4
#define  MMQ_Y_Q5_1_AMPERE 32
#define NWARPS_Q5_1_AMPERE 4
#else
#define  MMQ_X_Q5_1_AMPERE 128
#define  MMQ_Y_Q5_1_AMPERE 64
#define NWARPS_Q5_1_AMPERE 4
#endif
#define  MMQ_X_Q5_1_PASCAL 64
#define  MMQ_Y_Q5_1_PASCAL 64
#define NWARPS_Q5_1_PASCAL 8

template <bool need_check> static void
    mul_mat_q5_1(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_1,
    sycl::half2 *tile_x_dm_q5_1, int *tile_y_qs, sycl::half2 *tile_y_ds) {
    int   * tile_x_ql = nullptr;
    sycl::half2 *tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    // sycl_todo: change according to hardware
    const int mmq_x  =  MMQ_X_Q5_1_AMPERE;
    const int mmq_y  =  MMQ_Y_Q5_1_AMPERE;
    const int nwarps = NWARPS_Q5_1_AMPERE;

    allocate_tiles_q5_1<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
                               tile_x_ql_q5_1, tile_x_dm_q5_1);
    mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps,
              load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ,
              vec_dot_q5_1_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
}
#define  MMQ_X_Q8_0_RDNA2  64
#define  MMQ_Y_Q8_0_RDNA2  128
#define NWARPS_Q8_0_RDNA2  8
#define  MMQ_X_Q8_0_RDNA1  64
#define  MMQ_Y_Q8_0_RDNA1  64
#define NWARPS_Q8_0_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q8_0_AMPERE 4
#define  MMQ_Y_Q8_0_AMPERE 32
#define NWARPS_Q8_0_AMPERE 4
#else
#define  MMQ_X_Q8_0_AMPERE 128
#define  MMQ_Y_Q8_0_AMPERE 64
#define NWARPS_Q8_0_AMPERE 4
#endif
#define  MMQ_X_Q8_0_PASCAL 64
#define  MMQ_Y_Q8_0_PASCAL 64
#define NWARPS_Q8_0_PASCAL 8

template <bool need_check> static void
    mul_mat_q8_0(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
    const sycl::nd_item<3> &item_ct1, int *tile_x_qs_q8_0, float *tile_x_d_q8_0,
    int *tile_y_qs, sycl::half2 *tile_y_ds) {
    int   * tile_x_ql = nullptr;
    sycl::half2 *tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    // sycl_todo: change according to hardware
    const int mmq_x  =  MMQ_X_Q8_0_AMPERE;
    const int mmq_y  =  MMQ_Y_Q8_0_AMPERE;
    const int nwarps = NWARPS_Q8_0_AMPERE;

    allocate_tiles_q8_0<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
                               tile_x_qs_q8_0, tile_x_d_q8_0);
    mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps,
              load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ,
              vec_dot_q8_0_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
}
#define  MMQ_X_Q2_K_RDNA2  64
#define  MMQ_Y_Q2_K_RDNA2  128
#define NWARPS_Q2_K_RDNA2  8
#define  MMQ_X_Q2_K_RDNA1  128
#define  MMQ_Y_Q2_K_RDNA1  32
#define NWARPS_Q2_K_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q2_K_AMPERE 4
#define  MMQ_Y_Q2_K_AMPERE 32
#define NWARPS_Q2_K_AMPERE 4
#else
#define  MMQ_X_Q2_K_AMPERE 64
#define  MMQ_Y_Q2_K_AMPERE 128
#define NWARPS_Q2_K_AMPERE 4
#endif
#define  MMQ_X_Q2_K_PASCAL 64
#define  MMQ_Y_Q2_K_PASCAL 64
#define NWARPS_Q2_K_PASCAL 8

template <bool need_check> static void
    mul_mat_q2_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q2_K,
    sycl::half2 *tile_x_dm_q2_K, int *tile_x_sc_q2_K, int *tile_y_qs,
    sycl::half2 *tile_y_ds) {
    int   * tile_x_ql = nullptr;
    sycl::half2 *tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    // sycl_todo: change according to hardware
    const int mmq_x  =  MMQ_X_Q2_K_AMPERE;
    const int mmq_y  =  MMQ_Y_Q2_K_AMPERE;
    const int nwarps = NWARPS_Q2_K_AMPERE;

    allocate_tiles_q2_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
                               tile_x_ql_q2_K, tile_x_dm_q2_K, tile_x_sc_q2_K);
    mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps,
              load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ,
              vec_dot_q2_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
}
#define  MMQ_X_Q3_K_RDNA2  128
#define  MMQ_Y_Q3_K_RDNA2  64
#define NWARPS_Q3_K_RDNA2  8
#define  MMQ_X_Q3_K_RDNA1  32
#define  MMQ_Y_Q3_K_RDNA1  128
#define NWARPS_Q3_K_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q3_K_AMPERE 4
#define  MMQ_Y_Q3_K_AMPERE 32
#define NWARPS_Q3_K_AMPERE 4
#else
#define  MMQ_X_Q3_K_AMPERE 128
#define  MMQ_Y_Q3_K_AMPERE 128
#define NWARPS_Q3_K_AMPERE 4
#endif
#define  MMQ_X_Q3_K_PASCAL 64
#define  MMQ_Y_Q3_K_PASCAL 64
#define NWARPS_Q3_K_PASCAL 8

template <bool need_check> static void
    mul_mat_q3_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q3_K,
    sycl::half2 *tile_x_dm_q3_K, int *tile_x_qh_q3_K, int *tile_x_sc_q3_K,
    int *tile_y_qs, sycl::half2 *tile_y_ds) {
    int   * tile_x_ql = nullptr;
    sycl::half2 *tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    // sycl_todo: change according to hardware
    const int mmq_x  =  MMQ_X_Q3_K_AMPERE;
    const int mmq_y  =  MMQ_Y_Q3_K_AMPERE;
    const int nwarps = NWARPS_Q3_K_AMPERE;

    allocate_tiles_q3_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
                               tile_x_ql_q3_K, tile_x_dm_q3_K, tile_x_qh_q3_K,
                               tile_x_sc_q3_K);
    mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps,
              load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ,
              vec_dot_q3_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
}
#define  MMQ_X_Q4_K_RDNA2  64
#define  MMQ_Y_Q4_K_RDNA2  128
#define NWARPS_Q4_K_RDNA2  8
#define  MMQ_X_Q4_K_RDNA1  32
#define  MMQ_Y_Q4_K_RDNA1  64
#define NWARPS_Q4_K_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q4_K_AMPERE 4
#define  MMQ_Y_Q4_K_AMPERE 32
#define NWARPS_Q4_K_AMPERE 4
#else
#define  MMQ_X_Q4_K_AMPERE 64
#define  MMQ_Y_Q4_K_AMPERE 128
#define NWARPS_Q4_K_AMPERE 4
#endif
#define  MMQ_X_Q4_K_PASCAL 64
#define  MMQ_Y_Q4_K_PASCAL 64
#define NWARPS_Q4_K_PASCAL 8

template <bool need_check> static void
    mul_mat_q4_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q4_K,
    sycl::half2 *tile_x_dm_q4_K, int *tile_x_sc_q4_K, int *tile_y_qs,
    sycl::half2 *tile_y_ds) {
    int   * tile_x_ql = nullptr;
    sycl::half2 *tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    // sycl_todo: change according to hardware
    const int mmq_x  =  MMQ_X_Q4_K_AMPERE;
    const int mmq_y  =  MMQ_Y_Q4_K_AMPERE;
    const int nwarps = NWARPS_Q4_K_AMPERE;

    allocate_tiles_q4_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
                               tile_x_ql_q4_K, tile_x_dm_q4_K, tile_x_sc_q4_K);
    mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps,
              load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ,
              vec_dot_q4_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
}
#define  MMQ_X_Q5_K_RDNA2  64
#define  MMQ_Y_Q5_K_RDNA2  128
#define NWARPS_Q5_K_RDNA2  8
#define  MMQ_X_Q5_K_RDNA1  32
#define  MMQ_Y_Q5_K_RDNA1  64
#define NWARPS_Q5_K_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q5_K_AMPERE 4
#define  MMQ_Y_Q5_K_AMPERE 32
#define NWARPS_Q5_K_AMPERE 4
#else
#define  MMQ_X_Q5_K_AMPERE 64
#define  MMQ_Y_Q5_K_AMPERE 128
#define NWARPS_Q5_K_AMPERE 4
#endif
#define  MMQ_X_Q5_K_PASCAL 64
#define  MMQ_Y_Q5_K_PASCAL 64
#define NWARPS_Q5_K_PASCAL 8

template <bool need_check> static void
    mul_mat_q5_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
    const sycl::nd_item<3> &item_ct1, int *tile_x_ql_q5_K,
    sycl::half2 *tile_x_dm_q5_K, int *tile_x_sc_q5_K, int *tile_y_qs,
    sycl::half2 *tile_y_ds) {
    int   * tile_x_ql = nullptr;
    sycl::half2 *tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    int   * tile_x_sc = nullptr;

    // sycl_todo: change according to hardware
    const int mmq_x  =  MMQ_X_Q5_K_AMPERE;
    const int mmq_y  =  MMQ_Y_Q5_K_AMPERE;
    const int nwarps = NWARPS_Q5_K_AMPERE;

    allocate_tiles_q5_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
                               tile_x_ql_q5_K, tile_x_dm_q5_K, tile_x_sc_q5_K);
    mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps,
              load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ,
              vec_dot_q5_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
}
#define  MMQ_X_Q6_K_RDNA2  64
#define  MMQ_Y_Q6_K_RDNA2  128
#define NWARPS_Q6_K_RDNA2  8
#define  MMQ_X_Q6_K_RDNA1  32
#define  MMQ_Y_Q6_K_RDNA1  64
#define NWARPS_Q6_K_RDNA1  8
#if defined(SYCL_USE_XMX)
#define  MMQ_X_Q6_K_AMPERE 4
#define  MMQ_Y_Q6_K_AMPERE 32
#define NWARPS_Q6_K_AMPERE 4
#else
#define  MMQ_X_Q6_K_AMPERE 64
#define  MMQ_Y_Q6_K_AMPERE 64
#define NWARPS_Q6_K_AMPERE 4
#endif
#define  MMQ_X_Q6_K_PASCAL 64
#define  MMQ_Y_Q6_K_PASCAL 64
#define NWARPS_Q6_K_PASCAL 8

template <bool need_check> static void
    mul_mat_q6_K(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst,
    const sycl::nd_item<3> &item_ct1, int *tile_x_ql, sycl::half2 *tile_x_dm,
    int *tile_x_sc, int *tile_y_qs, sycl::half2 *tile_y_ds) {
    // int   * tile_x_ql = nullptr;
    // sycl::half2 *tile_x_dm = nullptr;
    int   * tile_x_qh = nullptr;
    // int   * tile_x_sc = nullptr;

    // sycl_todo: change according to hardware
    const int mmq_x  =  MMQ_X_Q6_K_AMPERE;
    const int mmq_y  =  MMQ_Y_Q6_K_AMPERE;
    const int nwarps = NWARPS_Q6_K_AMPERE;

    allocate_tiles_q6_K<mmq_y>(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc,
                               tile_x_ql, tile_x_dm, tile_x_sc);
    mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps,
              load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ,
              vec_dot_q6_K_q8_1_mul_mat>(
        vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst, tile_x_ql,
        tile_x_dm, tile_x_qh, tile_x_sc, item_ct1, tile_y_qs, tile_y_ds);
}
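
// Host-side launchers. Each ggml_mul_mat_<type>_q8_1_sycl picks a tile shape
// from the tables above based on the device's capability level (VER_GEN13
// down to VER_4VEC), derives the grid from the matrix dimensions, allocates
// the kernel's tiles as sycl::local_accessor objects, and submits one of two
// instantiations: need_check == false when nrows_x divides evenly into row
// tiles, need_check == true otherwise.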
static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q4_0_RDNA2;
        mmq_y  =  MMQ_Y_Q4_0_RDNA2;
        nwarps = NWARPS_Q4_0_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q4_0_RDNA1;
        mmq_y  =  MMQ_Y_Q4_0_RDNA1;
        nwarps = NWARPS_Q4_0_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q4_0_AMPERE;
        mmq_y  =  MMQ_Y_Q4_0_AMPERE;
        nwarps = NWARPS_Q4_0_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q4_0_PASCAL;
        mmq_y  =  MMQ_Y_Q4_0_PASCAL;
        nwarps = NWARPS_Q4_0_PASCAL;
    } else {
        GGML_ASSERT(false);
    }
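
    // Ceil-division grid: round the matrix up to whole tiles. As an
    // illustration, nrows_x = 4096 with mmq_y = 128 gives block_num_x = 32,
    // while nrows_x = 4100 gives 33; the partial last tile is handled by the
    // need_check variant selected below.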
    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:20: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_qs_q4_0_acc_ct1.get_pointer(),
                            tile_x_d_q4_0_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
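        // Bounds-checked variant: an identical submission except that
        // need_check = true makes the tile loader clamp row indices, which is
        // required whenever nrows_x is not a multiple of the tile height.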
        const bool need_check = true;
        /*
        DPCT1049:21: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_qs_q4_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q4_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_0) + mmq_y / QI4_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_qs_q4_0_acc_ct1.get_pointer(),
                            tile_x_d_q4_0_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
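
// The launchers below follow the q4_0 pattern verbatim; they differ only in
// which tile-shape table they consult and in the local-accessor sizes, which
// must match what the corresponding allocate_tiles_* function expects.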
static void ggml_mul_mat_q4_1_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q4_1_RDNA2;
        mmq_y  =  MMQ_Y_Q4_1_RDNA2;
        nwarps = NWARPS_Q4_1_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q4_1_RDNA1;
        mmq_y  =  MMQ_Y_Q4_1_RDNA1;
        nwarps = NWARPS_Q4_1_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q4_1_AMPERE;
        mmq_y  =  MMQ_Y_Q4_1_AMPERE;
        nwarps = NWARPS_Q4_1_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q4_1_PASCAL;
        mmq_y  =  MMQ_Y_Q4_1_PASCAL;
        nwarps = NWARPS_Q4_1_PASCAL;
    } else {
        GGML_ASSERT(false);
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:22: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_1<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_qs_q4_1_acc_ct1.get_pointer(),
                            tile_x_dm_q4_1_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
        const bool need_check = true;
        /*
        DPCT1049:23: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_qs_q4_1_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_1_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_1) + mmq_y / QI4_1),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_1<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_qs_q4_1_acc_ct1.get_pointer(),
                            tile_x_dm_q4_1_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
static void ggml_mul_mat_q5_0_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q5_0_RDNA2;
        mmq_y  =  MMQ_Y_Q5_0_RDNA2;
        nwarps = NWARPS_Q5_0_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q5_0_RDNA1;
        mmq_y  =  MMQ_Y_Q5_0_RDNA1;
        nwarps = NWARPS_Q5_0_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q5_0_AMPERE;
        mmq_y  =  MMQ_Y_Q5_0_AMPERE;
        nwarps = NWARPS_Q5_0_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q5_0_PASCAL;
        mmq_y  =  MMQ_Y_Q5_0_PASCAL;
        nwarps = NWARPS_Q5_0_PASCAL;
    } else {
        GGML_ASSERT(false);
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:24: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q5_0_acc_ct1.get_pointer(),
                            tile_x_d_q5_0_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
        const bool need_check = true;
        /*
        DPCT1049:25: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_0_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q5_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_0) + mmq_y / QI5_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q5_0_acc_ct1.get_pointer(),
                            tile_x_d_q5_0_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
static void ggml_mul_mat_q5_1_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q5_1_RDNA2;
        mmq_y  =  MMQ_Y_Q5_1_RDNA2;
        nwarps = NWARPS_Q5_1_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q5_1_RDNA1;
        mmq_y  =  MMQ_Y_Q5_1_RDNA1;
        nwarps = NWARPS_Q5_1_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q5_1_AMPERE;
        mmq_y  =  MMQ_Y_Q5_1_AMPERE;
        nwarps = NWARPS_Q5_1_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q5_1_PASCAL;
        mmq_y  =  MMQ_Y_Q5_1_PASCAL;
        nwarps = NWARPS_Q5_1_PASCAL;
    } else {
        GGML_ASSERT(false);
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:26: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_1<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q5_1_acc_ct1.get_pointer(),
                            tile_x_dm_q5_1_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
        const bool need_check = true;
        /*
        DPCT1049:27: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_1_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_1_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_1) + mmq_y / QI5_1),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_1<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q5_1_acc_ct1.get_pointer(),
                            tile_x_dm_q5_1_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
static void ggml_mul_mat_q8_0_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q8_0_RDNA2;
        mmq_y  =  MMQ_Y_Q8_0_RDNA2;
        nwarps = NWARPS_Q8_0_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q8_0_RDNA1;
        mmq_y  =  MMQ_Y_Q8_0_RDNA1;
        nwarps = NWARPS_Q8_0_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q8_0_AMPERE;
        mmq_y  =  MMQ_Y_Q8_0_AMPERE;
        nwarps = NWARPS_Q8_0_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q8_0_PASCAL;
        mmq_y  =  MMQ_Y_Q8_0_PASCAL;
        nwarps = NWARPS_Q8_0_PASCAL;
    } else {
        GGML_ASSERT(false);
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:28: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q8_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_qs_q8_0_acc_ct1.get_pointer(),
                            tile_x_d_q8_0_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
        const bool need_check = true;
        /*
        DPCT1049:29: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_qs_q8_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<float, 1> tile_x_d_q8_0_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI8_0) + mmq_y / QI8_0),
                    cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q8_0<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_qs_q8_0_acc_ct1.get_pointer(),
                            tile_x_d_q8_0_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
static void ggml_mul_mat_q2_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q2_K_RDNA2;
        mmq_y  =  MMQ_Y_Q2_K_RDNA2;
        nwarps = NWARPS_Q2_K_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q2_K_RDNA1;
        mmq_y  =  MMQ_Y_Q2_K_RDNA1;
        nwarps = NWARPS_Q2_K_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q2_K_AMPERE;
        mmq_y  =  MMQ_Y_Q2_K_AMPERE;
        nwarps = NWARPS_Q2_K_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q2_K_PASCAL;
        mmq_y  =  MMQ_Y_Q2_K_PASCAL;
        nwarps = NWARPS_Q2_K_PASCAL;
    } else {
        GGML_ASSERT(false);
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:30: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q2_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q2_K_acc_ct1.get_pointer(),
                            tile_x_dm_q2_K_acc_ct1.get_pointer(),
                            tile_x_sc_q2_K_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
        const bool need_check = true;
        /*
        DPCT1049:31: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI2_K) + mmq_y / QI2_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q2_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q2_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q2_K_acc_ct1.get_pointer(),
                            tile_x_dm_q2_K_acc_ct1.get_pointer(),
                            tile_x_sc_q2_K_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
static void ggml_mul_mat_q3_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {
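    // The q3_K MMQ path is only compiled for 256-wide super-blocks: with
    // QK_K != 256 the body below is preprocessed away and this launcher
    // becomes a no-op.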
#if QK_K == 256

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q3_K_RDNA2;
        mmq_y  =  MMQ_Y_Q3_K_RDNA2;
        nwarps = NWARPS_Q3_K_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q3_K_RDNA1;
        mmq_y  =  MMQ_Y_Q3_K_RDNA1;
        nwarps = NWARPS_Q3_K_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q3_K_AMPERE;
        mmq_y  =  MMQ_Y_Q3_K_AMPERE;
        nwarps = NWARPS_Q3_K_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q3_K_PASCAL;
        mmq_y  =  MMQ_Y_Q3_K_PASCAL;
        nwarps = NWARPS_Q3_K_PASCAL;
    } else {
        GGML_ASSERT(false);
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:32: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q3_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q3_K_acc_ct1.get_pointer(),
                            tile_x_dm_q3_K_acc_ct1.get_pointer(),
                            tile_x_qh_q3_K_acc_ct1.get_pointer(),
                            tile_x_sc_q3_K_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
        const bool need_check = true;
        /*
        DPCT1049:33: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI3_K) + mmq_y / QI3_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_qh_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 2) + mmq_y / 2), cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q3_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 4) + mmq_y / 4), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q3_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q3_K_acc_ct1.get_pointer(),
                            tile_x_dm_q3_K_acc_ct1.get_pointer(),
                            tile_x_qh_q3_K_acc_ct1.get_pointer(),
                            tile_x_sc_q3_K_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
#endif
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
static void ggml_mul_mat_q4_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q4_K_RDNA2;
        mmq_y  =  MMQ_Y_Q4_K_RDNA2;
        nwarps = NWARPS_Q4_K_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q4_K_RDNA1;
        mmq_y  =  MMQ_Y_Q4_K_RDNA1;
        nwarps = NWARPS_Q4_K_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q4_K_AMPERE;
        mmq_y  =  MMQ_Y_Q4_K_AMPERE;
        nwarps = NWARPS_Q4_K_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q4_K_PASCAL;
        mmq_y  =  MMQ_Y_Q4_K_PASCAL;
        nwarps = NWARPS_Q4_K_PASCAL;
    } else {
        GGML_ASSERT(false);
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:34: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q4_K_acc_ct1.get_pointer(),
                            tile_x_dm_q4_K_acc_ct1.get_pointer(),
                            tile_x_sc_q4_K_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
        const bool need_check = true;
        /*
        DPCT1049:35: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI4_K) + mmq_y / QI4_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q4_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q4_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q4_K_acc_ct1.get_pointer(),
                            tile_x_dm_q4_K_acc_ct1.get_pointer(),
                            tile_x_sc_q4_K_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
            << ", line:" << __LINE__ << std::endl;
  std::exit(1);
}
  2298. static void ggml_mul_mat_q5_K_q8_1_sycl(const void *vx, const void *vy,
  2299. float *dst, const int ncols_x,
  2300. const int nrows_x, const int ncols_y,
  2301. const int nrows_y, const int nrows_dst,
  2302. dpct::queue_ptr stream) try {
  2303. int id;
  2304. SYCL_CHECK(
  2305. CHECK_TRY_ERROR(id = get_current_device_id()));
  2306. const int compute_capability = ggml_sycl_info().devices[id].cc;
  2307. int mmq_x, mmq_y, nwarps;
  2308. if (compute_capability >= VER_GEN13) {
  2309. mmq_x = MMQ_X_Q5_K_RDNA2;
  2310. mmq_y = MMQ_Y_Q5_K_RDNA2;
  2311. nwarps = NWARPS_Q5_K_RDNA2;
  2312. } else if (compute_capability >= VER_GEN12) {
  2313. mmq_x = MMQ_X_Q5_K_RDNA1;
  2314. mmq_y = MMQ_Y_Q5_K_RDNA1;
  2315. nwarps = NWARPS_Q5_K_RDNA1;
  2316. } else if (compute_capability >= VER_GEN9) {
  2317. mmq_x = MMQ_X_Q5_K_AMPERE;
  2318. mmq_y = MMQ_Y_Q5_K_AMPERE;
  2319. nwarps = NWARPS_Q5_K_AMPERE;
  2320. } else if (compute_capability >= VER_4VEC) {
  2321. mmq_x = MMQ_X_Q5_K_PASCAL;
  2322. mmq_y = MMQ_Y_Q5_K_PASCAL;
  2323. nwarps = NWARPS_Q5_K_PASCAL;
  2324. } else {
  2325. GGML_ASSERT(false);
  2326. }
  2327. const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
  2328. const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
  2329. const sycl::range<3> block_nums(1, block_num_y, block_num_x);
  2330. const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);
    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:36: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q5_K_acc_ct1.get_pointer(),
                            tile_x_dm_q5_K_acc_ct1.get_pointer(),
                            tile_x_sc_q5_K_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
        const bool need_check = true;
        /*
        DPCT1049:37: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_q5_K_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_q5_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI5_K) + mmq_y / QI5_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_q5_K_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q5_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_q5_K_acc_ct1.get_pointer(),
                            tile_x_dm_q5_K_acc_ct1.get_pointer(),
                            tile_x_sc_q5_K_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
    std::cerr << exc.what() << " Exception caught at file: " << __FILE__
              << ", line: " << __LINE__ << std::endl;
    std::exit(1);
}
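
// Host-side launcher for the Q6_K x Q8_1 mul_mat_q kernel; the tile-size
// selection and launch logic mirror the Q4_K/Q5_K launchers above.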
static void ggml_mul_mat_q6_K_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols_x,
                                        const int nrows_x, const int ncols_y,
                                        const int nrows_y, const int nrows_dst,
                                        dpct::queue_ptr stream) try {

    int id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(id = get_current_device_id()));
    const int compute_capability = ggml_sycl_info().devices[id].cc;

    int mmq_x, mmq_y, nwarps;
    if (compute_capability >= VER_GEN13) {
        mmq_x  =  MMQ_X_Q6_K_RDNA2;
        mmq_y  =  MMQ_Y_Q6_K_RDNA2;
        nwarps = NWARPS_Q6_K_RDNA2;
    } else if (compute_capability >= VER_GEN12) {
        mmq_x  =  MMQ_X_Q6_K_RDNA1;
        mmq_y  =  MMQ_Y_Q6_K_RDNA1;
        nwarps = NWARPS_Q6_K_RDNA1;
    } else if (compute_capability >= VER_GEN9) {
        mmq_x  =  MMQ_X_Q6_K_AMPERE;
        mmq_y  =  MMQ_Y_Q6_K_AMPERE;
        nwarps = NWARPS_Q6_K_AMPERE;
    } else if (compute_capability >= VER_4VEC) {
        mmq_x  =  MMQ_X_Q6_K_PASCAL;
        mmq_y  =  MMQ_Y_Q6_K_PASCAL;
        nwarps = NWARPS_Q6_K_PASCAL;
    } else {
        GGML_ASSERT(false);
    }

    const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
    const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
    const sycl::range<3> block_nums(1, block_num_y, block_num_x);
    const sycl::range<3> block_dims(1, nwarps, WARP_SIZE);

    if (nrows_x % mmq_y == 0) {
        const bool need_check = false;
        /*
        DPCT1049:38: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q6_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_acc_ct1.get_pointer(),
                            tile_x_dm_acc_ct1.get_pointer(),
                            tile_x_sc_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    } else {
        const bool need_check = true;
        /*
        DPCT1049:39: The work-group size passed to the SYCL kernel may exceed
        the limit. To get the device limit, query
        info::device::max_work_group_size. Adjust the work-group size if needed.
        */
        {
            dpct::has_capability_or_fail(stream->get_device(),
                                         {sycl::aspect::fp16});

            stream->submit([&](sycl::handler &cgh) {
                sycl::local_accessor<int, 1> tile_x_ql_acc_ct1(
                    sycl::range<1>(mmq_y * (2 * WARP_SIZE) + mmq_y), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_x_dm_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / QI6_K) + mmq_y / QI6_K),
                    cgh);
                sycl::local_accessor<int, 1> tile_x_sc_acc_ct1(
                    sycl::range<1>(mmq_y * (WARP_SIZE / 8) + mmq_y / 8), cgh);
                sycl::local_accessor<int, 1> tile_y_qs_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE), cgh);
                sycl::local_accessor<sycl::half2, 1> tile_y_ds_acc_ct1(
                    sycl::range<1>(mmq_x * WARP_SIZE / QI8_1), cgh);

                cgh.parallel_for(
                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
                    [=](sycl::nd_item<3> item_ct1) {
                        mul_mat_q6_K<need_check>(
                            vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y,
                            nrows_dst, item_ct1,
                            tile_x_ql_acc_ct1.get_pointer(),
                            tile_x_dm_acc_ct1.get_pointer(),
                            tile_x_sc_acc_ct1.get_pointer(),
                            tile_y_qs_acc_ct1.get_pointer(),
                            tile_y_ds_acc_ct1.get_pointer());
                    });
            });
        }
    }
}
catch (sycl::exception const &exc) {
    std::cerr << exc.what() << " Exception caught at file: " << __FILE__
              << ", line: " << __LINE__ << std::endl;
    std::exit(1);
}
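
// Dispatch entry point for quantized mat-mul: routes the (src0, q8_1-quantized
// src1) pair to the launcher matching src0's quantization type.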
void ggml_sycl_op_mul_mat_q(
    ggml_backend_sycl_context & ctx,
    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
    const int64_t src1_ncols, const int64_t src1_padded_row_size,
    const dpct::queue_ptr &stream) try {

    const int64_t ne00 = src0->ne[0];
    const int64_t ne10 = src1->ne[0];
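    // src1 has been quantized to q8_1, which stores values in blocks of QK8_1
    // elements, so its row length must be a whole number of blocks.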
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne0 = dst->ne[0];
    const int64_t row_diff = row_high - row_low;

    int device_id;
    SYCL_CHECK(
        CHECK_TRY_ERROR(device_id = get_current_device_id()));
    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the mul_mat_q kernel writes into
    const int64_t nrows_dst = device_id == ctx.device ? ne0 : row_diff;
    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            ggml_mul_mat_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_1:
            ggml_mul_mat_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_0:
            ggml_mul_mat_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_1:
            ggml_mul_mat_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        case GGML_TYPE_Q8_0:
            ggml_mul_mat_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        case GGML_TYPE_Q2_K:
            ggml_mul_mat_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        case GGML_TYPE_Q3_K:
            ggml_mul_mat_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_K:
            ggml_mul_mat_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_K:
            ggml_mul_mat_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        case GGML_TYPE_Q6_K:
            ggml_mul_mat_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_ncols, src1_padded_row_size, nrows_dst, stream);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }

    (void) src1;
    (void) dst;
    (void) src1_ddf_i;
}
catch (sycl::exception const &exc) {
    std::cerr << exc.what() << " Exception caught at file: " << __FILE__
              << ", line: " << __LINE__ << std::endl;
    std::exit(1);
}