helper.hpp

//
// MIT license
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: MIT
//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#ifndef GGML_SYCL_DPCT_HELPER_HPP
#define GGML_SYCL_DPCT_HELPER_HPP

#include <sycl/sycl.hpp>
#include <sycl/half_type.hpp>
#include <oneapi/mkl.hpp>

#include <algorithm>
#include <array>
#include <complex>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>
#include <map>
#include <mutex>
#include <sstream>
#include <vector>

#include "ggml.h"

#if defined(__linux__)
#include <sys/mman.h>
#include <unistd.h>
#include <sys/syscall.h>
#elif defined(_WIN64)
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#else
#error "Only Windows and Linux are supported."
#endif

#define DPCT_COMPATIBILITY_TEMP (900)

#if defined(_MSC_VER)
#define __dpct_align__(n) __declspec(align(n))
#define __dpct_inline__ __forceinline
#define __dpct_noinline__ __declspec(noinline)
#else
#define __dpct_align__(n) __attribute__((aligned(n)))
#define __dpct_inline__ __inline__ __attribute__((always_inline))
#define __dpct_noinline__ __attribute__((noinline))
#endif

inline std::string get_device_type_name(const sycl::device &Device) {
    auto DeviceType = Device.get_info<sycl::info::device::device_type>();
    switch (DeviceType) {
    case sycl::info::device_type::cpu:
        return "cpu";
    case sycl::info::device_type::gpu:
        return "gpu";
    case sycl::info::device_type::host:
        return "host";
    case sycl::info::device_type::accelerator:
        return "acc";
    default:
        return "unknown";
    }
}

inline std::string get_device_backend_and_type(const sycl::device &device) {
    std::stringstream device_type;
    sycl::backend backend = device.get_backend();
    device_type << backend << ":" << get_device_type_name(device);
    return device_type.str();
}

namespace dpct
{
    typedef sycl::queue *queue_ptr;
    typedef sycl::event *event_ptr;
    typedef char *device_ptr;
    typedef uint8_t byte_t;
    typedef sycl::buffer<byte_t> buffer_t;

    /// SYCL default exception handler
    inline auto exception_handler = [](sycl::exception_list exceptions)
    {
        for (std::exception_ptr const &e : exceptions)
        {
            try
            {
                std::rethrow_exception(e);
            }
            catch (sycl::exception const &e)
            {
                std::cerr << "Caught asynchronous SYCL exception:" << std::endl
                          << e.what() << std::endl
                          << "Exception caught at file:" << __FILE__
                          << ", line:" << __LINE__ << std::endl;
            }
        }
    };
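
    // Usage sketch (illustrative only, not part of this header): the handler
    // is meant to be passed at queue construction; asynchronous errors are
    // then reported when the queue is waited on.
    //
    //   sycl::queue q{sycl::gpu_selector_v, dpct::exception_handler};
    //   q.submit([&](sycl::handler &cgh) { /* ... */ });
    //   q.wait_and_throw(); // async exceptions funnel through the handler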

    enum error_code
    {
        success = 0,
        default_error = 999
    };

    enum memcpy_direction
    {
        host_to_host,
        host_to_device,
        device_to_host,
        device_to_device,
        automatic
    };

    enum memory_region
    {
        global = 0, // device global memory
        constant,   // device constant memory
        local,      // device local memory
        shared,     // memory which can be accessed by host and device
    };

    enum class library_data_t : unsigned char
    {
        real_float = 0,
        complex_float,
        real_double,
        complex_double,
        real_half,
        complex_half,
        real_bfloat16,
        complex_bfloat16,
        real_int4,
        complex_int4,
        real_uint4,
        complex_uint4,
        real_int8,
        complex_int8,
        real_uint8,
        complex_uint8,
        real_int16,
        complex_int16,
        real_uint16,
        complex_uint16,
        real_int32,
        complex_int32,
        real_uint32,
        complex_uint32,
        real_int64,
        complex_int64,
        real_uint64,
        complex_uint64,
        real_int8_4,
        real_int8_32,
        real_uint8_4,
        library_data_t_size
    };

    template <typename T>
    struct DataType
    {
        using T2 = T;
    };
    template <typename T>
    struct DataType<sycl::vec<T, 2>>
    {
        using T2 = std::complex<T>;
    };
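
    // Illustrative example (not part of the API): DataType maps a two-element
    // SYCL vector onto the matching std::complex type, so oneMKL complex
    // routines can be called with the element type the library expects.
    //
    //   static_assert(std::is_same_v<dpct::DataType<float>::T2, float>);
    //   static_assert(std::is_same_v<dpct::DataType<sycl::vec<float, 2>>::T2,
    //                                std::complex<float>>);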

    static void destroy_event(event_ptr event)
    {
        delete event;
    }

    static inline unsigned int get_tid()
    {
#if defined(__linux__)
        return syscall(SYS_gettid);
#elif defined(_WIN64)
        return GetCurrentThreadId();
#else
#error "Only Windows and Linux are supported."
#endif
    }

    namespace detail
    {
        static void get_version(const sycl::device &dev, int &major, int &minor)
        {
            // The version string has one of the following formats:
            // a. OpenCL<space><major.minor><space><vendor-specific-information>
            // b. <major.minor>
            // c. <AmdGcnArchName> e.g. gfx1030
            std::string ver;
            ver = dev.get_info<sycl::info::device::version>();
            std::string::size_type i = 0;
            while (i < ver.size()) {
                if (isdigit(ver[i]))
                    break;
                i++;
            }
            major = std::stoi(&(ver[i]));
            while (i < ver.size()) {
                if (ver[i] == '.')
                    break;
                i++;
            }
            if (i < ver.size()) {
                // a. and b.
                i++;
                minor = std::stoi(&(ver[i]));
            } else {
                // c.
                minor = 0;
            }
        }
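
        // Worked example (illustrative only): for the string "OpenCL 3.0 NEO"
        // the first digit scan yields major = 3, and the scan past the '.'
        // yields minor = 0. For "gfx1030" there is no '.', so the whole digit
        // run becomes major = 1030 and minor defaults to 0.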

        template <typename tag, typename T>
        class generic_error_type
        {
        public:
            generic_error_type() = default;
            generic_error_type(T value) : value{value} {}
            operator T() const { return value; }

        private:
            T value;
        };
    } // namespace detail

    /// Pitched 2D/3D memory data.
    class pitched_data
    {
    public:
        pitched_data() : pitched_data(nullptr, 0, 0, 0) {}
        pitched_data(void *data, size_t pitch, size_t x, size_t y)
            : _data(data), _pitch(pitch), _x(x), _y(y) {}

        void *get_data_ptr() { return _data; }
        void set_data_ptr(void *data) { _data = data; }

        size_t get_pitch() { return _pitch; }
        void set_pitch(size_t pitch) { _pitch = pitch; }

        size_t get_x() { return _x; }
        void set_x(size_t x) { _x = x; }

        size_t get_y() { return _y; }
        void set_y(size_t y) { _y = y; }

    private:
        void *_data;
        size_t _pitch, _x, _y;
    };

    class device_info
    {
    public:
        // get interface
        const char *get_name() const { return _name; }
        char *get_name() { return _name; }
        template <typename WorkItemSizesTy = sycl::range<3>,
                  std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
                                       std::is_same_v<WorkItemSizesTy, int *>,
                                   int> = 0>
        auto get_max_work_item_sizes() const
        {
            if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
                return sycl::range<3>(_max_work_item_sizes_i[0],
                                      _max_work_item_sizes_i[1],
                                      _max_work_item_sizes_i[2]);
            else
            {
                return _max_work_item_sizes_i;
            }
        }
        template <typename WorkItemSizesTy = sycl::range<3>,
                  std::enable_if_t<std::is_same_v<WorkItemSizesTy, sycl::range<3>> ||
                                       std::is_same_v<WorkItemSizesTy, int *>,
                                   int> = 0>
        auto get_max_work_item_sizes()
        {
            if constexpr (std::is_same_v<WorkItemSizesTy, sycl::range<3>>)
                return sycl::range<3>(_max_work_item_sizes_i[0],
                                      _max_work_item_sizes_i[1],
                                      _max_work_item_sizes_i[2]);
            else
            {
                return _max_work_item_sizes_i;
            }
        }
        bool get_host_unified_memory() const { return _host_unified_memory; }
        int get_major_version() const { return _major; }
        int get_minor_version() const { return _minor; }
        int get_integrated() const { return _integrated; }
        int get_max_clock_frequency() const { return _frequency; }
        int get_max_compute_units() const { return _max_compute_units; }
        int get_max_work_group_size() const { return _max_work_group_size; }
        int get_max_sub_group_size() const { return _max_sub_group_size; }
        int get_max_work_items_per_compute_unit() const
        {
            return _max_work_items_per_compute_unit;
        }
        int get_max_register_size_per_work_group() const
        {
            return _max_register_size_per_work_group;
        }
        template <typename NDRangeSizeTy = size_t *,
                  std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
                                       std::is_same_v<NDRangeSizeTy, int *>,
                                   int> = 0>
        auto get_max_nd_range_size() const
        {
            if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
                return _max_nd_range_size;
            else
                return _max_nd_range_size_i;
        }
        template <typename NDRangeSizeTy = size_t *,
                  std::enable_if_t<std::is_same_v<NDRangeSizeTy, size_t *> ||
                                       std::is_same_v<NDRangeSizeTy, int *>,
                                   int> = 0>
        auto get_max_nd_range_size()
        {
            if constexpr (std::is_same_v<NDRangeSizeTy, size_t *>)
                return _max_nd_range_size;
            else
                return _max_nd_range_size_i;
        }
        size_t get_global_mem_size() const { return _global_mem_size; }
        size_t get_local_mem_size() const { return _local_mem_size; }
        size_t get_max_mem_alloc_size() const { return _max_mem_alloc_size; }
        /// Returns the maximum clock rate of the device's global memory in kHz.
        /// If the compiler does not support this API, the default value of
        /// 3200000 kHz is returned.
        unsigned int get_memory_clock_rate() const { return _memory_clock_rate; }
        /// Returns the maximum bus width between the device and memory in bits.
        /// If the compiler does not support this API, the default value of
        /// 64 bits is returned.
        unsigned int get_memory_bus_width() const { return _memory_bus_width; }
        uint32_t get_device_id() const { return _device_id; }
        std::array<unsigned char, 16> get_uuid() const { return _uuid; }
        /// Returns global memory cache size in bytes.
        unsigned int get_global_mem_cache_size() const
        {
            return _global_mem_cache_size;
        }

        // set interface
        void set_name(const char *name)
        {
            size_t length = strlen(name);
            if (length < 256)
            {
                std::memcpy(_name, name, length + 1);
            }
            else
            {
                std::memcpy(_name, name, 255);
                _name[255] = '\0';
            }
        }
        void set_max_work_item_sizes(const sycl::range<3> max_work_item_sizes)
        {
            for (int i = 0; i < 3; ++i)
                _max_work_item_sizes_i[i] = max_work_item_sizes[i];
        }
        [[deprecated]] void
        set_max_work_item_sizes(const sycl::id<3> max_work_item_sizes)
        {
            for (int i = 0; i < 3; ++i)
            {
                _max_work_item_sizes_i[i] = max_work_item_sizes[i];
            }
        }
        void set_host_unified_memory(bool host_unified_memory)
        {
            _host_unified_memory = host_unified_memory;
        }
        void set_major_version(int major) { _major = major; }
        void set_minor_version(int minor) { _minor = minor; }
        void set_integrated(int integrated) { _integrated = integrated; }
        void set_max_clock_frequency(int frequency) { _frequency = frequency; }
        void set_max_compute_units(int max_compute_units)
        {
            _max_compute_units = max_compute_units;
        }
        void set_global_mem_size(size_t global_mem_size)
        {
            _global_mem_size = global_mem_size;
        }
        void set_local_mem_size(size_t local_mem_size)
        {
            _local_mem_size = local_mem_size;
        }
        void set_max_mem_alloc_size(size_t max_mem_alloc_size)
        {
            _max_mem_alloc_size = max_mem_alloc_size;
        }
        void set_max_work_group_size(int max_work_group_size)
        {
            _max_work_group_size = max_work_group_size;
        }
        void set_max_sub_group_size(int max_sub_group_size)
        {
            _max_sub_group_size = max_sub_group_size;
        }
        void
        set_max_work_items_per_compute_unit(int max_work_items_per_compute_unit)
        {
            _max_work_items_per_compute_unit = max_work_items_per_compute_unit;
        }
        void set_max_nd_range_size(int max_nd_range_size[])
        {
            for (int i = 0; i < 3; i++)
            {
                _max_nd_range_size[i] = max_nd_range_size[i];
                _max_nd_range_size_i[i] = max_nd_range_size[i];
            }
        }
        void set_memory_clock_rate(unsigned int memory_clock_rate)
        {
            _memory_clock_rate = memory_clock_rate;
        }
        void set_memory_bus_width(unsigned int memory_bus_width)
        {
            _memory_bus_width = memory_bus_width;
        }
        void
        set_max_register_size_per_work_group(int max_register_size_per_work_group)
        {
            _max_register_size_per_work_group = max_register_size_per_work_group;
        }
        void set_device_id(uint32_t device_id)
        {
            _device_id = device_id;
        }
        void set_uuid(std::array<unsigned char, 16> uuid)
        {
            _uuid = std::move(uuid);
        }
        void set_global_mem_cache_size(unsigned int global_mem_cache_size)
        {
            _global_mem_cache_size = global_mem_cache_size;
        }

    private:
        char _name[256];
        int _max_work_item_sizes_i[3];
        bool _host_unified_memory = false;
        int _major;
        int _minor;
        int _integrated = 0;
        int _frequency;
        // Estimated value 3200000 kHz used as the default.
        unsigned int _memory_clock_rate = 3200000;
        // Estimated value 64 bits used as the default.
        unsigned int _memory_bus_width = 64;
        unsigned int _global_mem_cache_size;
        int _max_compute_units;
        int _max_work_group_size;
        int _max_sub_group_size;
        int _max_work_items_per_compute_unit;
        int _max_register_size_per_work_group;
        size_t _global_mem_size;
        size_t _local_mem_size;
        size_t _max_mem_alloc_size;
        size_t _max_nd_range_size[3];
        int _max_nd_range_size_i[3];
        uint32_t _device_id;
        std::array<unsigned char, 16> _uuid;
    };

    static int get_major_version(const sycl::device &dev)
    {
        int major, minor;
        detail::get_version(dev, major, minor);
        return major;
    }

    static int get_minor_version(const sycl::device &dev)
    {
        int major, minor;
        detail::get_version(dev, major, minor);
        return minor;
    }

    static void get_device_info(device_info &out, const sycl::device &dev)
    {
        device_info prop;
        prop.set_name(dev.get_info<sycl::info::device::name>().c_str());
        int major, minor;
        detail::get_version(dev, major, minor);
        prop.set_major_version(major);
        prop.set_minor_version(minor);
        prop.set_max_work_item_sizes(
#if (__SYCL_COMPILER_VERSION && __SYCL_COMPILER_VERSION < 20220902)
            // oneAPI DPC++ compiler older than 2022/09/02, where
            // max_work_item_sizes is an enum class element
            dev.get_info<sycl::info::device::max_work_item_sizes>());
#else
            // SYCL 2020-conformant code, max_work_item_sizes is a struct
            // templated by an int
            dev.get_info<sycl::info::device::max_work_item_sizes<3>>());
#endif
        prop.set_host_unified_memory(dev.has(sycl::aspect::usm_host_allocations));
        prop.set_max_clock_frequency(
            dev.get_info<sycl::info::device::max_clock_frequency>() * 1000);
        prop.set_max_compute_units(
            dev.get_info<sycl::info::device::max_compute_units>());
        prop.set_max_work_group_size(
            dev.get_info<sycl::info::device::max_work_group_size>());
        prop.set_global_mem_size(dev.get_info<sycl::info::device::global_mem_size>());
        prop.set_local_mem_size(dev.get_info<sycl::info::device::local_mem_size>());
        prop.set_max_mem_alloc_size(dev.get_info<sycl::info::device::max_mem_alloc_size>());
#if (defined(SYCL_EXT_INTEL_DEVICE_INFO) && SYCL_EXT_INTEL_DEVICE_INFO >= 6)
        if (dev.has(sycl::aspect::ext_intel_memory_clock_rate))
        {
            unsigned int tmp =
                dev.get_info<sycl::ext::intel::info::device::memory_clock_rate>();
            if (tmp != 0)
                prop.set_memory_clock_rate(1000 * tmp);
        }
        if (dev.has(sycl::aspect::ext_intel_memory_bus_width))
        {
            prop.set_memory_bus_width(
                dev.get_info<sycl::ext::intel::info::device::memory_bus_width>());
        }
        if (dev.has(sycl::aspect::ext_intel_device_id))
        {
            prop.set_device_id(
                dev.get_info<sycl::ext::intel::info::device::device_id>());
        }
        if (dev.has(sycl::aspect::ext_intel_device_info_uuid))
        {
            prop.set_uuid(dev.get_info<sycl::ext::intel::info::device::uuid>());
        }
#elif defined(_MSC_VER) && !defined(__clang__)
#pragma message("get_device_info: querying memory_clock_rate and \
memory_bus_width is not supported by the compiler used. \
Using 3200000 kHz as the memory_clock_rate default value. \
Using 64 bits as the memory_bus_width default value.")
#else
#warning "get_device_info: querying memory_clock_rate and \
memory_bus_width is not supported by the compiler used. \
Using 3200000 kHz as the memory_clock_rate default value. \
Using 64 bits as the memory_bus_width default value."
#endif
        size_t max_sub_group_size = 1;
        std::vector<size_t> sub_group_sizes =
            dev.get_info<sycl::info::device::sub_group_sizes>();
        for (const auto &sub_group_size : sub_group_sizes)
        {
            if (max_sub_group_size < sub_group_size)
                max_sub_group_size = sub_group_size;
        }
        prop.set_max_sub_group_size(max_sub_group_size);
        prop.set_max_work_items_per_compute_unit(
            dev.get_info<sycl::info::device::max_work_group_size>());
        int max_nd_range_size[] = {0x7FFFFFFF, 0x7FFFFFFF, 0x7FFFFFFF};
        prop.set_max_nd_range_size(max_nd_range_size);
        // Estimated max register size per work group; update the value
        // according to device properties if needed.
        prop.set_max_register_size_per_work_group(65536);
        prop.set_global_mem_cache_size(
            dev.get_info<sycl::info::device::global_mem_cache_size>());
        out = prop;
    }
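
    // Usage sketch (illustrative only): populating a device_info and reading
    // a few of its fields.
    //
    //   dpct::device_info info;
    //   dpct::get_device_info(info, sycl::device{sycl::gpu_selector_v});
    //   printf("%s: %d CUs, %zu bytes global memory\n", info.get_name(),
    //          info.get_max_compute_units(), info.get_global_mem_size());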

    /// dpct device extension
    class device_ext : public sycl::device
    {
        typedef std::mutex mutex_type;

    public:
        device_ext() : sycl::device(), _ctx(*this) {}
        ~device_ext()
        {
            std::lock_guard<mutex_type> lock(m_mutex);
            clear_queues();
        }
        device_ext(const sycl::device &base) : sycl::device(base), _ctx(*this)
        {
            std::lock_guard<mutex_type> lock(m_mutex);
            init_queues();
        }

        int is_native_atomic_supported() { return 0; }
        int get_major_version() const
        {
            return dpct::get_major_version(*this);
        }
        int get_minor_version() const
        {
            return dpct::get_minor_version(*this);
        }
        int get_max_compute_units() const
        {
            return get_device_info().get_max_compute_units();
        }
        /// Return the maximum clock frequency of this device in kHz.
        int get_max_clock_frequency() const
        {
            return get_device_info().get_max_clock_frequency();
        }
        int get_integrated() const { return get_device_info().get_integrated(); }
        int get_max_sub_group_size() const
        {
            return get_device_info().get_max_sub_group_size();
        }
        int get_max_register_size_per_work_group() const
        {
            return get_device_info().get_max_register_size_per_work_group();
        }
        int get_max_work_group_size() const
        {
            return get_device_info().get_max_work_group_size();
        }
        int get_mem_base_addr_align() const
        {
            return get_info<sycl::info::device::mem_base_addr_align>();
        }
        size_t get_global_mem_size() const
        {
            return get_device_info().get_global_mem_size();
        }
        size_t get_max_mem_alloc_size() const
        {
            return get_device_info().get_max_mem_alloc_size();
        }
        /// Get the number of bytes of free and total memory on the SYCL device.
        /// \param [out] free_memory The number of bytes of free memory on the SYCL device.
        /// \param [out] total_memory The number of bytes of total memory on the SYCL device.
        void get_memory_info(size_t &free_memory, size_t &total_memory)
        {
            total_memory = get_device_info().get_global_mem_size();
            const char *warning_info = "get_memory_info: [warning] ext_intel_free_memory is not "
                                       "supported (export/set ZES_ENABLE_SYSMAN=1 to support), "
                                       "use total memory as free memory";
#if (defined(__SYCL_COMPILER_VERSION) && __SYCL_COMPILER_VERSION >= 20221105)
            if (!has(sycl::aspect::ext_intel_free_memory))
            {
                std::cerr << warning_info << std::endl;
                free_memory = total_memory;
            }
            else
            {
                free_memory = get_info<sycl::ext::intel::info::device::free_memory>();
            }
#else
            std::cerr << warning_info << std::endl;
            free_memory = total_memory;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma message("Querying the number of bytes of free memory is not supported")
#else
#warning "Querying the number of bytes of free memory is not supported"
#endif
#endif
        }
        void get_device_info(device_info &out) const
        {
            dpct::get_device_info(out, *this);
        }
        device_info get_device_info() const
        {
            device_info prop;
            dpct::get_device_info(prop, *this);
            return prop;
        }
        void reset()
        {
            std::lock_guard<mutex_type> lock(m_mutex);
            clear_queues();
            init_queues();
        }
        sycl::queue &in_order_queue() { return *_q_in_order; }
        sycl::queue &out_of_order_queue() { return *_q_out_of_order; }
        sycl::queue &default_queue()
        {
            return in_order_queue();
        }
        void queues_wait_and_throw()
        {
            std::unique_lock<mutex_type> lock(m_mutex);
            std::vector<std::shared_ptr<sycl::queue>> current_queues(
                _queues);
            lock.unlock();
            for (const auto &q : current_queues)
            {
                q->wait_and_throw();
            }
            // Guard the destruction of current_queues to make sure the ref
            // count is safe.
            lock.lock();
        }
        sycl::queue *create_queue(bool enable_exception_handler = false)
        {
            return create_in_order_queue(enable_exception_handler);
        }
        sycl::queue *create_queue(sycl::context context, sycl::device device,
                                  bool enable_exception_handler = false) {
            return create_in_order_queue(context, device, enable_exception_handler);
        }
        sycl::queue *create_in_order_queue(bool enable_exception_handler = false) {
            std::lock_guard<mutex_type> lock(m_mutex);
            return create_queue_impl(enable_exception_handler,
                                     sycl::property::queue::in_order());
        }
        sycl::queue *create_in_order_queue(sycl::context context, sycl::device device,
                                           bool enable_exception_handler = false) {
            std::lock_guard<mutex_type> lock(m_mutex);
            return create_queue_impl(context, device, enable_exception_handler,
                                     sycl::property::queue::in_order());
        }
        sycl::queue *create_out_of_order_queue(bool enable_exception_handler = false) {
            std::lock_guard<mutex_type> lock(m_mutex);
            return create_queue_impl(enable_exception_handler);
        }
        void destroy_queue(sycl::queue *&queue)
        {
            std::lock_guard<mutex_type> lock(m_mutex);
            _queues.erase(std::remove_if(_queues.begin(), _queues.end(),
                                         [=](const std::shared_ptr<sycl::queue> &q) -> bool
                                         {
                                             return q.get() == queue;
                                         }),
                          _queues.end());
            queue = nullptr;
        }
        void set_saved_queue(sycl::queue *q)
        {
            std::lock_guard<mutex_type> lock(m_mutex);
            _saved_queue = q;
        }
        sycl::queue *get_saved_queue() const
        {
            std::lock_guard<mutex_type> lock(m_mutex);
            return _saved_queue;
        }
        sycl::context get_context() const { return _ctx; }

    private:
        void clear_queues()
        {
            _queues.clear();
            _q_in_order = _q_out_of_order = _saved_queue = nullptr;
        }
        void init_queues()
        {
            _q_in_order = create_queue_impl(true, sycl::property::queue::in_order());
            _q_out_of_order = create_queue_impl(true);
            _saved_queue = &default_queue();
        }
        /// Caller should acquire resource \p m_mutex before calling this function.
        template <class... Properties>
        sycl::queue *create_queue_impl(bool enable_exception_handler,
                                       Properties... properties)
        {
            sycl::async_handler eh = {};
            if (enable_exception_handler)
            {
                eh = exception_handler;
            }
            _queues.push_back(std::make_shared<sycl::queue>(
                _ctx, *this, eh,
                sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED
                    sycl::property::queue::enable_profiling(),
#endif
                    properties...)));
            return _queues.back().get();
        }
        template <class... Properties>
        sycl::queue *create_queue_impl(sycl::context context, sycl::device device,
                                       bool enable_exception_handler,
                                       Properties... properties) {
            sycl::async_handler eh = {};
            if (enable_exception_handler) {
                eh = exception_handler;
            }
            _queues.push_back(std::make_shared<sycl::queue>(
                context, device, eh,
                sycl::property_list(
#ifdef DPCT_PROFILING_ENABLED
                    sycl::property::queue::enable_profiling(),
#endif
                    properties...)));
            return _queues.back().get();
        }
        void get_version(int &major, int &minor) const
        {
            detail::get_version(*this, major, minor);
        }

        sycl::queue *_q_in_order, *_q_out_of_order;
        sycl::queue *_saved_queue;
        sycl::context _ctx;
        std::vector<std::shared_ptr<sycl::queue>> _queues;
        mutable mutex_type m_mutex;
    };
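
    // Usage sketch (illustrative only): wrapping a raw sycl::device and
    // creating an extra in-order queue with the default exception handler
    // enabled. `ptr` and `bytes` are hypothetical.
    //
    //   dpct::device_ext dev{sycl::device{sycl::gpu_selector_v}};
    //   sycl::queue *q = dev.create_in_order_queue(/*enable_exception_handler=*/true);
    //   q->memset(ptr, 0, bytes).wait();
    //   dev.destroy_queue(q);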

    /// device manager
    class dev_mgr
    {
    public:
        device_ext &current_device()
        {
            unsigned int dev_id = current_device_id();
            check_id(dev_id);
            return *_devs[dev_id];
        }
        device_ext &cpu_device() const
        {
            std::lock_guard<std::recursive_mutex> lock(m_mutex);
            if (_cpu_device == -1)
            {
                throw std::runtime_error("no valid cpu device");
            }
            else
            {
                return *_devs[_cpu_device];
            }
        }
        device_ext &get_device(unsigned int id) const
        {
            std::lock_guard<std::recursive_mutex> lock(m_mutex);
            check_id(id);
            return *_devs[id];
        }
        unsigned int current_device_id() const
        {
            std::lock_guard<std::recursive_mutex> lock(m_mutex);
            auto it = _thread2dev_map.find(get_tid());
            if (it != _thread2dev_map.end())
                return it->second;
            return DEFAULT_DEVICE_ID;
        }
        /// Select device with a device ID.
        /// \param [in] id The id of the device which can
        /// be obtained through get_device_id(const sycl::device).
        void select_device(unsigned int id)
        {
            std::lock_guard<std::recursive_mutex> lock(m_mutex);
            check_id(id);
            _thread2dev_map[get_tid()] = id;
        }
        unsigned int device_count() { return _devs.size(); }
        unsigned int get_device_id(const sycl::device &dev)
        {
            unsigned int id = 0;
            for (auto dev_item : _devs)
            {
                if (*dev_item == dev)
                {
                    break;
                }
                id++;
            }
            return id;
        }
        template <class DeviceSelector>
        std::enable_if_t<
            std::is_invocable_r_v<int, DeviceSelector, const sycl::device &>>
        select_device(const DeviceSelector &selector = sycl::gpu_selector_v)
        {
            sycl::device selected_device = sycl::device(selector);
            unsigned int selected_device_id = get_device_id(selected_device);
            select_device(selected_device_id);
        }
        /// Returns the instance of the device manager singleton.
        static dev_mgr &instance()
        {
            static dev_mgr d_m;
            return d_m;
        }
        dev_mgr(const dev_mgr &) = delete;
        dev_mgr &operator=(const dev_mgr &) = delete;
        dev_mgr(dev_mgr &&) = delete;
        dev_mgr &operator=(dev_mgr &&) = delete;

    private:
        mutable std::recursive_mutex m_mutex;
        static bool compare_dev(sycl::device &device1, sycl::device &device2)
        {
            sycl::backend backend1 = device1.get_backend();
            sycl::backend backend2 = device2.get_backend();
            // Level Zero backends always come first.
            if (backend1 == sycl::backend::ext_oneapi_level_zero && backend2 != sycl::backend::ext_oneapi_level_zero) return true;
            if (backend1 != sycl::backend::ext_oneapi_level_zero && backend2 == sycl::backend::ext_oneapi_level_zero) return false;
            dpct::device_info prop1;
            dpct::get_device_info(prop1, device1);
            dpct::device_info prop2;
            dpct::get_device_info(prop2, device2);
            return prop1.get_max_compute_units() > prop2.get_max_compute_units();
        }
        static int convert_backend_index(std::string &backend) {
            if (backend == "ext_oneapi_level_zero:gpu") return 0;
            if (backend == "opencl:gpu") return 1;
            if (backend == "ext_oneapi_cuda:gpu") return 2;
            if (backend == "ext_oneapi_hip:gpu") return 3;
            if (backend == "opencl:cpu") return 4;
            if (backend == "opencl:acc") return 5;
            printf("convert_backend_index: can't handle backend=%s\n", backend.c_str());
            GGML_ASSERT(false);
        }
        static bool compare_backend(std::string &backend1, std::string &backend2) {
            return convert_backend_index(backend1) < convert_backend_index(backend2);
        }
        dev_mgr()
        {
            sycl::device default_device =
                sycl::device(sycl::default_selector_v);
            _devs.push_back(std::make_shared<device_ext>(default_device));
            if (default_device.is_cpu())
                _cpu_device = 0;
            // Collect the other devices, grouped by backend:type.
            std::vector<sycl::device> sycl_all_devs;
            auto Platforms = sycl::platform::get_platforms();
            std::map<std::string, std::vector<sycl::device>> backend_devices;
            while (!Platforms.empty()) {
                auto Platform = Platforms.back();
                Platforms.pop_back();
                auto devices = Platform.get_devices();
                if (devices.empty())
                    continue; // skip platforms that expose no devices
                std::string backend_type = get_device_backend_and_type(devices[0]);
                for (const auto &device : devices) {
                    backend_devices[backend_type].push_back(device);
                }
            }
            std::vector<std::string> keys;
            for (auto it = backend_devices.begin(); it != backend_devices.end(); ++it) {
                keys.push_back(it->first);
            }
            std::sort(keys.begin(), keys.end(), compare_backend);
            for (auto &key : keys) {
                std::vector<sycl::device> devs = backend_devices[key];
                std::sort(devs.begin(), devs.end(), compare_dev);
                for (const auto &dev : devs) {
                    sycl_all_devs.push_back(dev);
                }
            }
            for (auto &dev : sycl_all_devs)
            {
                if (dev == default_device)
                {
                    continue;
                }
                _devs.push_back(std::make_shared<device_ext>(dev));
                if (_cpu_device == -1 && dev.is_cpu())
                {
                    _cpu_device = _devs.size() - 1;
                }
            }
        }
        void check_id(unsigned int id) const
        {
            if (id >= _devs.size())
            {
                throw std::runtime_error("invalid device id");
            }
        }

        std::vector<std::shared_ptr<device_ext>> _devs;
        /// DEFAULT_DEVICE_ID is used if current_device_id() can not find the
        /// current thread id in _thread2dev_map, which means the default
        /// device should be used for the current thread.
        const unsigned int DEFAULT_DEVICE_ID = 0;
        /// thread-id to device-id map.
        std::map<unsigned int, unsigned int> _thread2dev_map;
        int _cpu_device = -1;
    };
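
    // Usage sketch (illustrative only): the singleton enumerates devices on
    // first use; device 0 is the default device unless the calling thread
    // selects another one.
    //
    //   auto &mgr = dpct::dev_mgr::instance();
    //   printf("%u SYCL devices\n", mgr.device_count());
    //   mgr.select_device(0);                      // per-thread selection
    //   dpct::device_ext &dev = mgr.current_device();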

    static inline sycl::queue &get_default_queue()
    {
        return dev_mgr::instance().current_device().default_queue();
    }

    namespace detail
    {
        enum class pointer_access_attribute
        {
            host_only = 0,
            device_only,
            host_device,
            end
        };

        static pointer_access_attribute get_pointer_attribute(sycl::queue &q,
                                                              const void *ptr)
        {
            switch (sycl::get_pointer_type(ptr, q.get_context()))
            {
            case sycl::usm::alloc::unknown:
                return pointer_access_attribute::host_only;
            case sycl::usm::alloc::device:
                return pointer_access_attribute::device_only;
            case sycl::usm::alloc::shared:
            case sycl::usm::alloc::host:
                return pointer_access_attribute::host_device;
            }
        }

        template <typename ArgT>
        inline constexpr std::uint64_t get_type_combination_id(ArgT Val)
        {
            static_assert((unsigned char)library_data_t::library_data_t_size <=
                                  std::numeric_limits<unsigned char>::max() &&
                          "library_data_t size exceeds limit.");
            static_assert(std::is_same_v<ArgT, library_data_t>, "Unsupported ArgT");
            return (std::uint64_t)Val;
        }

        template <typename FirstT, typename... RestT>
        inline constexpr std::uint64_t get_type_combination_id(FirstT FirstVal,
                                                               RestT... RestVal)
        {
            static_assert((std::uint8_t)library_data_t::library_data_t_size <=
                                  std::numeric_limits<unsigned char>::max() &&
                          "library_data_t size exceeds limit.");
            static_assert(sizeof...(RestT) <= 8 && "Too many parameters");
            static_assert(std::is_same_v<FirstT, library_data_t>, "Unsupported FirstT");
            return get_type_combination_id(RestVal...) << 8 | ((std::uint64_t)FirstVal);
        }
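
        // Worked example (illustrative only): each library_data_t value
        // occupies one byte of the result, with the first argument in the
        // least significant byte, so a pair of types can drive a switch
        // statement. `from_type` and `to_type` are hypothetical runtime values.
        //
        //   switch (get_type_combination_id(from_type, to_type)) {
        //   case get_type_combination_id(library_data_t::real_half,
        //                                library_data_t::real_float):
        //       /* half -> float path */
        //       break;
        //   }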

        class mem_mgr
        {
            mem_mgr()
            {
                // Reserved address space; no real memory allocation happens here.
#if defined(__linux__)
                mapped_address_space =
                    (byte_t *)mmap(nullptr, mapped_region_size, PROT_NONE,
                                   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
#elif defined(_WIN64)
                mapped_address_space = (byte_t *)VirtualAlloc(
                    NULL,               // NULL specified as the base address parameter
                    mapped_region_size, // Size of allocation
                    MEM_RESERVE,        // Allocate reserved pages
                    PAGE_NOACCESS);     // Protection = no access
#else
#error "Only Windows and Linux are supported."
#endif
                next_free = mapped_address_space;
            }

        public:
            using buffer_id_t = int;

            struct allocation
            {
                buffer_t buffer;
                byte_t *alloc_ptr;
                size_t size;
            };

            ~mem_mgr()
            {
#if defined(__linux__)
                munmap(mapped_address_space, mapped_region_size);
#elif defined(_WIN64)
                VirtualFree(mapped_address_space, 0, MEM_RELEASE);
#else
#error "Only Windows and Linux are supported."
#endif
            }

            mem_mgr(const mem_mgr &) = delete;
            mem_mgr &operator=(const mem_mgr &) = delete;
            mem_mgr(mem_mgr &&) = delete;
            mem_mgr &operator=(mem_mgr &&) = delete;

            /// Allocate
            void *mem_alloc(size_t size)
            {
                if (!size)
                    return nullptr;
                std::lock_guard<std::mutex> lock(m_mutex);
                if (next_free + size > mapped_address_space + mapped_region_size)
                {
                    throw std::runtime_error("dpct_malloc: out of memory for virtual memory pool");
                }
                // Allocation
                sycl::range<1> r(size);
                buffer_t buf(r);
                allocation A{buf, next_free, size};
                // Map allocation to device pointer
                void *result = next_free;
                m_map.emplace(next_free + size, A);
                // Update pointer to the next free space.
                next_free += (size + extra_padding + alignment - 1) & ~(alignment - 1);
                return result;
            }
            /// Deallocate
            void mem_free(const void *ptr)
            {
                if (!ptr)
                    return;
                std::lock_guard<std::mutex> lock(m_mutex);
                auto it = get_map_iterator(ptr);
                m_map.erase(it);
            }
            /// map: device pointer -> allocation(buffer, alloc_ptr, size)
            allocation translate_ptr(const void *ptr)
            {
                std::lock_guard<std::mutex> lock(m_mutex);
                auto it = get_map_iterator(ptr);
                return it->second;
            }
            /// Check whether the pointer represents a device pointer.
            bool is_device_ptr(const void *ptr) const
            {
                std::lock_guard<std::mutex> lock(m_mutex);
                return (mapped_address_space <= ptr) &&
                       (ptr < mapped_address_space + mapped_region_size);
            }
            /// Returns the instance of the memory manager singleton.
            static mem_mgr &instance()
            {
                static mem_mgr m;
                return m;
            }

        private:
            std::map<byte_t *, allocation> m_map;
            mutable std::mutex m_mutex;
            byte_t *mapped_address_space;
            byte_t *next_free;
            const size_t mapped_region_size = 128ull * 1024 * 1024 * 1024;
            const size_t alignment = 256;
            /// This padding may be defined to some positive value to debug
            /// out of bound accesses.
            const size_t extra_padding = 0;

            std::map<byte_t *, allocation>::iterator get_map_iterator(const void *ptr)
            {
                auto it = m_map.upper_bound((byte_t *)ptr);
                if (it == m_map.end())
                {
                    // Not a virtual pointer.
                    throw std::runtime_error("can not get buffer from non-virtual pointer");
                }
                const allocation &alloc = it->second;
                if (ptr < alloc.alloc_ptr)
                {
                    // Out of bound.
                    // This may happen if there's a gap between allocations due to
                    // alignment or extra padding and the pointer points into this gap.
                    throw std::runtime_error("invalid virtual pointer");
                }
                return it;
            }
        };
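
        // Design note and usage sketch (illustrative only): mem_mgr reserves
        // a 128 GiB span of address space up front, hands out 256-byte-aligned
        // slices of it as "virtual" device pointers, and, because m_map is
        // keyed by each allocation's end address, upper_bound() resolves any
        // interior pointer back to its sycl::buffer.
        //
        //   void *p = mem_mgr::instance().mem_alloc(1024);
        //   auto alloc = mem_mgr::instance().translate_ptr(
        //       (char *)p + 100);          // interior pointers resolve too
        //   mem_mgr::instance().mem_free(p);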

        template <class T, memory_region Memory, size_t Dimension>
        class accessor;

        template <memory_region Memory, class T = byte_t>
        class memory_traits
        {
        public:
            static constexpr sycl::access::target target =
                sycl::access::target::device;
            static constexpr sycl::access_mode mode =
                (Memory == constant) ? sycl::access_mode::read
                                     : sycl::access_mode::read_write;
            static constexpr size_t type_size = sizeof(T);
            using element_t =
                typename std::conditional<Memory == constant, const T, T>::type;
            using value_t = typename std::remove_cv<T>::type;
            template <size_t Dimension = 1>
            using accessor_t = typename std::conditional<
                Memory == local, sycl::local_accessor<value_t, Dimension>,
                sycl::accessor<T, Dimension, mode, target>>::type;
            using pointer_t = T *;
        };

        static inline void *dpct_malloc(size_t size, sycl::queue &q)
        {
            return sycl::malloc_device(size, q.get_device(), q.get_context());
        }

#define PITCH_DEFAULT_ALIGN(x) (((x) + 31) & ~(0x1F))
        static inline void *dpct_malloc(size_t &pitch, size_t x, size_t y, size_t z,
                                        sycl::queue &q)
        {
            pitch = PITCH_DEFAULT_ALIGN(x);
            return dpct_malloc(pitch * y * z, q);
        }
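
        // Worked example (illustrative only): PITCH_DEFAULT_ALIGN rounds the
        // row width up to a multiple of 32 bytes. For a 100-byte-wide, 4-row,
        // 2-slice region the pitch becomes 128 and 128 * 4 * 2 = 1024 bytes
        // are allocated. `q` is a hypothetical queue.
        //
        //   size_t pitch;
        //   void *p = dpct_malloc(pitch, /*x=*/100, /*y=*/4, /*z=*/2, q);
        //   // pitch == 128; row r starts at (char *)p + r * pitch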

        /**
         * @brief Sets \p value to the first \p size elements starting from \p dev_ptr in \p q.
         * @tparam valueT The type of the element to be set.
         * @param [in] q The queue in which the operation is done.
         * @param [in] dev_ptr Pointer to the virtual device memory address.
         * @param [in] value The value to be set.
         * @param [in] size Number of elements to be set to the value.
         * @return An event representing the memset operation.
         */
        template <typename valueT>
        static inline sycl::event dpct_memset(sycl::queue &q, void *dev_ptr,
                                              valueT value, size_t size)
        {
            return q.fill(dev_ptr, value, size);
        }

        /**
         * @brief Sets \p value to the 3D memory region pointed to by \p data in \p q.
         * @tparam valueT The type of the element to be set.
         * @param [in] q The queue in which the operation is done.
         * @param [in] data Pointer to the pitched device memory region.
         * @param [in] value The value to be set.
         * @param [in] size The 3D memory region size, in number of elements.
         * @return An event list representing the memset operations.
         */
        template <typename valueT>
        static inline std::vector<sycl::event>
        dpct_memset(sycl::queue &q, pitched_data data, valueT value,
                    sycl::range<3> size)
        {
            std::vector<sycl::event> event_list;
            size_t slice = data.get_pitch() * data.get_y();
            unsigned char *data_surface = (unsigned char *)data.get_data_ptr();
            for (size_t z = 0; z < size.get(2); ++z)
            {
                unsigned char *data_ptr = data_surface;
                for (size_t y = 0; y < size.get(1); ++y)
                {
                    event_list.push_back(dpct_memset(q, data_ptr, value, size.get(0)));
                    data_ptr += data.get_pitch();
                }
                data_surface += slice;
            }
            return event_list;
        }

        /**
         * @brief Sets \p val to the pitched 2D memory region pointed to by \p ptr in \p q.
         * @tparam valueT The type of the element to be set.
         * @param [in] q The queue in which the operation is done.
         * @param [in] ptr Pointer to the virtual device memory.
         * @param [in] pitch The pitch size in number of elements, including padding.
         * @param [in] val The value to be set.
         * @param [in] x The width of the memory region, in number of elements.
         * @param [in] y The height of the memory region, in number of elements.
         * @return An event list representing the memset operations.
         */
        template <typename valueT>
        static inline std::vector<sycl::event>
        dpct_memset(sycl::queue &q, void *ptr, size_t pitch, valueT val, size_t x,
                    size_t y)
        {
            return dpct_memset(q, pitched_data(ptr, pitch, x, 1), val,
                               sycl::range<3>(x, y, 1));
        }
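
        // Usage sketch (illustrative only): zeroing a pitched 2D region row by
        // row; one fill event is produced per row, so the caller waits on the
        // whole list. `q` and `p` are a hypothetical queue and allocation.
        //
        //   auto events = dpct_memset(q, p, /*pitch=*/128, (unsigned char)0,
        //                             /*x=*/100, /*y=*/4);
        //   sycl::event::wait(events);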

        static memcpy_direction deduce_memcpy_direction(sycl::queue &q, void *to_ptr,
                                                        const void *from_ptr,
                                                        memcpy_direction dir)
        {
            switch (dir)
            {
            case memcpy_direction::host_to_host:
            case memcpy_direction::host_to_device:
            case memcpy_direction::device_to_host:
            case memcpy_direction::device_to_device:
                return dir;
            case memcpy_direction::automatic:
            {
                // table[to_attribute][from_attribute]
                static const memcpy_direction
                    direction_table[static_cast<unsigned>(pointer_access_attribute::end)]
                                   [static_cast<unsigned>(pointer_access_attribute::end)] =
                        {{memcpy_direction::host_to_host,
                          memcpy_direction::device_to_host,
                          memcpy_direction::host_to_host},
                         {memcpy_direction::host_to_device,
                          memcpy_direction::device_to_device,
                          memcpy_direction::device_to_device},
                         {memcpy_direction::host_to_host,
                          memcpy_direction::device_to_device,
                          memcpy_direction::device_to_device}};
                return direction_table[static_cast<unsigned>(get_pointer_attribute(
                    q, to_ptr))][static_cast<unsigned>(get_pointer_attribute(q, from_ptr))];
            }
            default:
                throw std::runtime_error("dpct_memcpy: invalid direction value");
            }
        }

        static sycl::event
        dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
                    memcpy_direction direction,
                    const std::vector<sycl::event> &dep_events = {})
        {
            if (!size)
                return sycl::event{};
            return q.memcpy(to_ptr, from_ptr, size, dep_events);
            GGML_UNUSED(direction);
        }

        // Get the actual copy range and make sure it will not exceed the
        // underlying allocation.
        static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
                                            size_t pitch)
        {
            return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
        }

        static inline size_t get_offset(sycl::id<3> id, size_t slice,
                                        size_t pitch)
        {
            return slice * id.get(2) + pitch * id.get(1) + id.get(0);
        }
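
        // Worked example (illustrative only): with pitch = 128, slice = 512
        // (pitch * 4 rows) and size = {100, 4, 2}, get_copy_range() yields
        // 512 * 1 + 128 * 3 + 100 = 996 bytes, i.e. it stops at the end of
        // the last row instead of including its tail padding. get_offset()
        // is the same formula applied to a start index rather than an extent.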
/// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
/// and \p from_range to another specified by \p to_ptr and \p to_range.
static inline std::vector<sycl::event>
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
            sycl::range<3> to_range, sycl::range<3> from_range,
            sycl::id<3> to_id, sycl::id<3> from_id,
            sycl::range<3> size, memcpy_direction direction,
            const std::vector<sycl::event> &dep_events = {})
{
    // RAII wrapper for a temporary host pointer.
    class host_buffer
    {
        void *_buf;
        size_t _size;
        sycl::queue &_q;
        const std::vector<sycl::event> &_deps; // events the deferred free depends on
    public:
        host_buffer(size_t size, sycl::queue &q,
                    const std::vector<sycl::event> &deps)
            : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
        void *get_ptr() const { return _buf; }
        size_t get_size() const { return _size; }
        ~host_buffer()
        {
            if (_buf)
            {
                _q.submit([&](sycl::handler &cgh)
                {
                    cgh.depends_on(_deps);
                    cgh.host_task([buf = _buf] { std::free(buf); });
                });
            }
        }
    };
    std::vector<sycl::event> event_list;
    size_t to_slice = to_range.get(1) * to_range.get(0),
           from_slice = from_range.get(1) * from_range.get(0);
    unsigned char *to_surface =
        (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
    const unsigned char *from_surface =
        (const unsigned char *)from_ptr +
        get_offset(from_id, from_slice, from_range.get(0));
    if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
    {
        return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
                            direction, dep_events)};
    }
    direction = deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
    size_t size_slice = size.get(1) * size.get(0);
    switch (direction)
    {
    case host_to_host:
        for (size_t z = 0; z < size.get(2); ++z)
        {
            unsigned char *to_ptr = to_surface;
            const unsigned char *from_ptr = from_surface;
            if (to_range.get(0) == from_range.get(0) &&
                to_range.get(0) == size.get(0))
            {
                event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
                                                 direction, dep_events));
            }
            else
            {
                for (size_t y = 0; y < size.get(1); ++y)
                {
                    event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
                                                     direction, dep_events));
                    to_ptr += to_range.get(0);
                    from_ptr += from_range.get(0);
                }
            }
            to_surface += to_slice;
            from_surface += from_slice;
        }
        break;
    case host_to_device:
    {
        host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
                        event_list);
        std::vector<sycl::event> host_events;
        if (to_slice == size_slice)
        {
            // Copy host data to a temp host buffer with the shape of target.
            host_events =
                dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
                            sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
                            host_to_host, dep_events);
        }
        else
        {
            // Copy host data to a temp host buffer with the shape of target.
            host_events = dpct_memcpy(
                q, buf.get_ptr(), from_surface, to_range, from_range,
                sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
                // The destination may contain padding bytes whose contents are
                // unknown but possibly meaningful, so pre-fill the temp buffer
                // from the device before overwriting the copied region.
                std::vector<sycl::event>{
                    dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
                                device_to_host, dep_events)});
        }
        // Copy from temp host buffer to device with only one submit.
        event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
                                         buf.get_size(), host_to_device,
                                         host_events));
        break;
    }
    case device_to_host:
    {
        host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
                        event_list);
        // Copy from the temp host buffer to the host target with reshaping.
        event_list = dpct_memcpy(
            q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
            sycl::id<3>(0, 0, 0), size, host_to_host,
            // Copy from device to temp host buffer with only one submit.
            std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
                                                 buf.get_size(),
                                                 device_to_host, dep_events)});
        break;
    }
    case device_to_device:
        event_list.push_back(q.submit([&](sycl::handler &cgh)
        {
            cgh.depends_on(dep_events);
            cgh.parallel_for<class dpct_memcpy_3d_detail>(
                size,
                [=](sycl::id<3> id) {
                    to_surface[get_offset(id, to_slice, to_range.get(0))] =
                        from_surface[get_offset(id, from_slice, from_range.get(0))];
                });
        }));
        break;
    default:
        throw std::runtime_error("dpct_memcpy: invalid direction value");
    }
    return event_list;
}

/// memcpy 2D/3D matrix specified by pitched_data.
static inline std::vector<sycl::event>
dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
            pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
            memcpy_direction direction = automatic)
{
    return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
                       sycl::range<3>(to.get_pitch(), to.get_y(), 1),
                       sycl::range<3>(from.get_pitch(), from.get_y(), 1),
                       to_id, from_id, size, direction);
}

/// memcpy 2D matrix with pitch.
static inline std::vector<sycl::event>
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
            size_t to_pitch, size_t from_pitch, size_t x, size_t y,
            memcpy_direction direction = automatic)
{
    return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
                       sycl::range<3>(from_pitch, y, 1),
                       sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
                       sycl::range<3>(x, y, 1), direction);
}
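
// Illustrative sketch (hypothetical helper, not part of the original file):
// copy a 64x32 sub-matrix of floats between two pitched allocations using the
// 2D overload above. The x extent is in bytes, the y extent in rows.
inline std::vector<sycl::event>
example_copy_submatrix(sycl::queue &q, void *dst, const void *src)
{
    const size_t to_pitch = 256, from_pitch = 512; // row strides in bytes
    return dpct_memcpy(q, dst, src, to_pitch, from_pitch,
                       64 * sizeof(float), 32, automatic);
}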
namespace deprecated
{
template <typename T, sycl::usm::alloc AllocKind>
class usm_allocator
{
private:
    using Alloc = sycl::usm_allocator<T, AllocKind>;
    Alloc _impl;

public:
    using value_type = typename std::allocator_traits<Alloc>::value_type;
    using pointer = typename std::allocator_traits<Alloc>::pointer;
    using const_pointer = typename std::allocator_traits<Alloc>::const_pointer;
    using void_pointer = typename std::allocator_traits<Alloc>::void_pointer;
    using const_void_pointer =
        typename std::allocator_traits<Alloc>::const_void_pointer;
    using reference = typename std::allocator_traits<Alloc>::value_type &;
    using const_reference =
        const typename std::allocator_traits<Alloc>::value_type &;
    using difference_type =
        typename std::allocator_traits<Alloc>::difference_type;
    using size_type = typename std::allocator_traits<Alloc>::size_type;
    using propagate_on_container_copy_assignment = typename std::allocator_traits<
        Alloc>::propagate_on_container_copy_assignment;
    using propagate_on_container_move_assignment = typename std::allocator_traits<
        Alloc>::propagate_on_container_move_assignment;
    using propagate_on_container_swap =
        typename std::allocator_traits<Alloc>::propagate_on_container_swap;
    using is_always_equal =
        typename std::allocator_traits<Alloc>::is_always_equal;

    template <typename U>
    struct rebind
    {
        typedef usm_allocator<U, AllocKind> other;
    };

    usm_allocator() : _impl(dpct::get_default_queue()) {}
    ~usm_allocator() {}
    usm_allocator(const usm_allocator &other) : _impl(other._impl) {}
    usm_allocator(usm_allocator &&other) : _impl(std::move(other._impl)) {}
    pointer address(reference r) { return &r; }
    const_pointer address(const_reference r) { return &r; }
    pointer allocate(size_type cnt, const_void_pointer hint = nullptr)
    {
        return std::allocator_traits<Alloc>::allocate(_impl, cnt, hint);
    }
    void deallocate(pointer p, size_type cnt)
    {
        std::allocator_traits<Alloc>::deallocate(_impl, p, cnt);
    }
    size_type max_size() const
    {
        return std::allocator_traits<Alloc>::max_size(_impl);
    }
    bool operator==(const usm_allocator &other) const { return _impl == other._impl; }
    bool operator!=(const usm_allocator &other) const { return _impl != other._impl; }
};
} // namespace deprecated
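
// Illustrative sketch (hypothetical, not part of the original helper): the
// deprecated allocator plugs straight into standard containers, yielding a
// std::vector whose storage is USM shared memory on the default queue.
inline void example_usm_vector()
{
    std::vector<float, deprecated::usm_allocator<float, sycl::usm::alloc::shared>> v;
    v.resize(16, 1.0f); // storage is reachable from both host and device
}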
inline void dpct_free(void *ptr,
                      const sycl::queue &q)
{
    if (ptr)
    {
        sycl::free(ptr, q.get_context());
    }
}

template <typename T>
inline auto get_memory(const void *x)
{
    T *new_x = reinterpret_cast<T *>(const_cast<void *>(x));
    return new_x;
}

template <typename T>
inline typename DataType<T>::T2 get_value(const T *s, sycl::queue &q)
{
    using Ty = typename DataType<T>::T2;
    Ty s_h;
    if (get_pointer_attribute(q, s) == pointer_access_attribute::device_only)
        detail::dpct_memcpy(q, (void *)&s_h, (const void *)s, sizeof(T), device_to_host)
            .wait();
    else
        s_h = *reinterpret_cast<const Ty *>(s);
    return s_h;
}
} // namespace detail

template <typename T>
inline auto get_value(const T *s, sycl::queue &q)
{
    return detail::get_value(s, q);
}
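
// Illustrative sketch (hypothetical, not part of the original helper):
// get_value transparently dereferences a scalar regardless of whether it
// lives in host memory or device-only USM memory.
inline float example_read_alpha(const float *alpha, sycl::queue &q)
{
    // If alpha is device-only this issues a small device-to-host copy and
    // waits on it; otherwise it is a plain host dereference.
    return get_value(alpha, q);
}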
namespace detail
{
template <class Ta, class Tb, class Tc, class Ts>
inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
                      oneapi::mkl::transpose b_trans, int m, int n, int k,
                      const void *alpha, const void *a, int lda, const void *b,
                      int ldb, const void *beta, void *c, int ldc)
{
    Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
    Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
    auto data_a = get_memory<const Ta>(a);
    auto data_b = get_memory<const Tb>(b);
    auto data_c = get_memory<Tc>(c);
    oneapi::mkl::blas::column_major::gemm(
        q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
        data_b, ldb, beta_value, data_c, ldc);
}

template <typename VecT, class BinaryOperation, class = void>
class vectorized_binary
{
public:
    inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
    {
        VecT v4;
        for (size_t i = 0; i < v4.size(); ++i)
        {
            v4[i] = binary_op(a[i], b[i]);
        }
        return v4;
    }
};

template <typename VecT, class BinaryOperation>
class vectorized_binary<
    VecT, BinaryOperation,
    std::void_t<std::invoke_result_t<BinaryOperation, VecT, VecT>>>
{
public:
    inline VecT operator()(VecT a, VecT b, const BinaryOperation binary_op)
    {
        return binary_op(a, b).template as<VecT>();
    }
};

template <class Ta, class Tb, class Tc, class Ts>
inline void gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
                            oneapi::mkl::transpose b_trans, int m, int n, int k,
                            const void *alpha, const void **a, int lda,
                            const void **b, int ldb, const void *beta, void **c,
                            int ldc, int batch_size)
{
    struct matrix_info_t
    {
        oneapi::mkl::transpose transpose_info[2];
        Ts value_info[2];
        std::int64_t size_info[3];
        std::int64_t ld_info[3];
        std::int64_t groupsize_info;
    };
    Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
    Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
    matrix_info_t *matrix_info =
        (matrix_info_t *)std::malloc(sizeof(matrix_info_t));
    matrix_info->transpose_info[0] = a_trans;
    matrix_info->transpose_info[1] = b_trans;
    matrix_info->value_info[0] = alpha_value;
    matrix_info->value_info[1] = beta_value;
    matrix_info->size_info[0] = m;
    matrix_info->size_info[1] = n;
    matrix_info->size_info[2] = k;
    matrix_info->ld_info[0] = lda;
    matrix_info->ld_info[1] = ldb;
    matrix_info->ld_info[2] = ldc;
    matrix_info->groupsize_info = batch_size;
    sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
        q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
        matrix_info->size_info, matrix_info->size_info + 1,
        matrix_info->size_info + 2, matrix_info->value_info,
        reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
        reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
        matrix_info->value_info + 1, reinterpret_cast<Tc **>(c),
        matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
    // Free the host-side argument block only after the batched GEMM has
    // finished using it.
    q.submit([&](sycl::handler &cgh)
    {
        cgh.depends_on(e);
        cgh.host_task([=] { std::free(matrix_info); });
    });
}

template <class Ta, class Tb, class Tc, class Ts>
inline void
gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
                oneapi::mkl::transpose b_trans, int m, int n,
                int k, const void *alpha, const void *a, int lda,
                long long int stride_a, const void *b, int ldb,
                long long int stride_b, const void *beta, void *c,
                int ldc, long long int stride_c, int batch_size)
{
    Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
    Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
    auto data_a = get_memory<const Ta>(a);
    auto data_b = get_memory<const Tb>(b);
    auto data_c = get_memory<Tc>(c);
    oneapi::mkl::blas::column_major::gemm_batch(
        q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
        stride_a, data_b, ldb, stride_b, beta_value,
        data_c, ldc, stride_c, batch_size);
}
} // namespace detail

template <typename VecT, class BinaryOperation>
inline unsigned vectorized_binary(unsigned a, unsigned b,
                                  const BinaryOperation binary_op)
{
    sycl::vec<unsigned, 1> v0{a}, v1{b};
    auto v2 = v0.as<VecT>();
    auto v3 = v1.as<VecT>();
    auto v4 =
        detail::vectorized_binary<VecT, BinaryOperation>()(v2, v3, binary_op);
    v0 = v4.template as<sycl::vec<unsigned, 1>>();
    return v0;
}

static void async_dpct_memcpy(void *to_ptr, const void *from_ptr, size_t size,
                              memcpy_direction direction = automatic,
                              sycl::queue &q = dpct::get_default_queue())
{
    detail::dpct_memcpy(q, to_ptr, from_ptr, size, direction);
}

static inline unsigned int select_device(unsigned int id)
{
    dev_mgr::instance().select_device(id);
    return id;
}

template <typename T>
T permute_sub_group_by_xor(sycl::sub_group g, T x, unsigned int mask,
                           unsigned int logical_sub_group_size = 32)
{
    unsigned int id = g.get_local_linear_id();
    unsigned int start_index =
        id / logical_sub_group_size * logical_sub_group_size;
    unsigned int target_offset = (id % logical_sub_group_size) ^ mask;
    return sycl::select_from_group(g, x,
                                   target_offset < logical_sub_group_size
                                       ? start_index + target_offset
                                       : id);
}
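
// Illustrative sketch (hypothetical helper, not part of the original file):
// a butterfly all-reduce over a logical 32-wide sub-group built from
// permute_sub_group_by_xor. Must be invoked from inside a kernel.
template <typename T>
inline T example_sub_group_sum(sycl::sub_group g, T x)
{
    for (unsigned int mask = 16; mask > 0; mask >>= 1)
        x = x + permute_sub_group_by_xor(g, x, mask);
    return x; // every work-item ends up holding the sub-group-wide sum
}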
template <typename T>
sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val)
{
    return sycl::vec<T, 1>(val)
        .template as<sycl::vec<
            std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>, 4>>()
        .template convert<T>();
}

template <typename T1, typename T2>
using dot_product_acc_t =
    std::conditional_t<std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
                       uint32_t, int32_t>;

template <typename T1, typename T2, typename T3>
inline auto dp4a(T1 a, T2 b, T3 c)
{
    dot_product_acc_t<T1, T2> res = c;
    auto va = extract_and_sign_or_zero_extend4(a);
    auto vb = extract_and_sign_or_zero_extend4(b);
    res += va[0] * vb[0];
    res += va[1] * vb[1];
    res += va[2] * vb[2];
    res += va[3] * vb[3];
    return res;
}
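
// Illustrative sketch (hypothetical, not part of the original file): the
// byte-wise dot product dp4a computes. With a = 0x04030201 and b = 0x01010101
// the four byte lanes of a are {1, 2, 3, 4} and every lane of b is 1, so the
// call below returns 0 + 1 + 2 + 3 + 4 == 10.
inline int example_dp4a() { return dp4a(0x04030201, 0x01010101, 0); }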
struct sub_sat
{
    template <typename T>
    auto operator()(const T x, const T y) const
    {
        return sycl::sub_sat(x, y);
    }
};
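
// Illustrative sketch (hypothetical, not part of the original file): per-lane
// saturated subtraction of two packed 16-bit pairs, the SYCL counterpart of
// CUDA's __vsubss2. Because sycl::sub_sat accepts sycl::vec arguments, the
// whole-vector specialization of detail::vectorized_binary is selected.
inline unsigned example_vsubss2(unsigned a, unsigned b)
{
    return vectorized_binary<sycl::vec<std::int16_t, 2>>(a, b, sub_sat());
}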
template <typename S, typename T>
inline T vectorized_min(T a, T b)
{
    sycl::vec<T, 1> v0{a}, v1{b};
    auto v2 = v0.template as<S>();
    auto v3 = v1.template as<S>();
    auto v4 = sycl::min(v2, v3);
    v0 = v4.template as<sycl::vec<T, 1>>();
    return v0;
}

inline float pow(const float a, const int b) { return sycl::pown(a, b); }
inline double pow(const double a, const int b) { return sycl::pown(a, b); }
inline float pow(const float a, const float b) { return sycl::pow(a, b); }
inline double pow(const double a, const double b) { return sycl::pow(a, b); }
template <typename T, typename U>
inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
pow(const T a, const U b)
{
    return sycl::pow(a, static_cast<T>(b));
}
template <typename T, typename U>
inline typename std::enable_if_t<!std::is_floating_point_v<T>, double>
pow(const T a, const U b)
{
    return sycl::pow(static_cast<double>(a), static_cast<double>(b));
}

// min function overloads.
// For floating-point types, `float` or `double` arguments are acceptable.
// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
// `std::int64_t` type arguments are acceptable.
inline double min(const double a, const float b)
{
    return sycl::fmin(a, static_cast<double>(b));
}
inline double min(const float a, const double b)
{
    return sycl::fmin(static_cast<double>(a), b);
}
inline float min(const float a, const float b) { return sycl::fmin(a, b); }
inline double min(const double a, const double b) { return sycl::fmin(a, b); }
inline std::uint32_t min(const std::uint32_t a, const std::int32_t b)
{
    return sycl::min(a, static_cast<std::uint32_t>(b));
}
inline std::uint32_t min(const std::int32_t a, const std::uint32_t b)
{
    return sycl::min(static_cast<std::uint32_t>(a), b);
}
inline std::int32_t min(const std::int32_t a, const std::int32_t b)
{
    return sycl::min(a, b);
}
inline std::uint32_t min(const std::uint32_t a, const std::uint32_t b)
{
    return sycl::min(a, b);
}
inline std::uint64_t min(const std::uint64_t a, const std::int64_t b)
{
    return sycl::min(a, static_cast<std::uint64_t>(b));
}
inline std::uint64_t min(const std::int64_t a, const std::uint64_t b)
{
    return sycl::min(static_cast<std::uint64_t>(a), b);
}
inline std::int64_t min(const std::int64_t a, const std::int64_t b)
{
    return sycl::min(a, b);
}
inline std::uint64_t min(const std::uint64_t a, const std::uint64_t b)
{
    return sycl::min(a, b);
}
inline std::uint64_t min(const std::uint64_t a, const std::int32_t b)
{
    return sycl::min(a, static_cast<std::uint64_t>(b));
}
inline std::uint64_t min(const std::int32_t a, const std::uint64_t b)
{
    return sycl::min(static_cast<std::uint64_t>(a), b);
}
inline std::uint64_t min(const std::uint64_t a, const std::uint32_t b)
{
    return sycl::min(a, static_cast<std::uint64_t>(b));
}
inline std::uint64_t min(const std::uint32_t a, const std::uint64_t b)
{
    return sycl::min(static_cast<std::uint64_t>(a), b);
}

// max function overloads.
// For floating-point types, `float` or `double` arguments are acceptable.
// For integer types, `std::uint32_t`, `std::int32_t`, `std::uint64_t` or
// `std::int64_t` type arguments are acceptable.
inline double max(const double a, const float b)
{
    return sycl::fmax(a, static_cast<double>(b));
}
inline double max(const float a, const double b)
{
    return sycl::fmax(static_cast<double>(a), b);
}
inline float max(const float a, const float b) { return sycl::fmax(a, b); }
inline double max(const double a, const double b) { return sycl::fmax(a, b); }
inline std::uint32_t max(const std::uint32_t a, const std::int32_t b)
{
    return sycl::max(a, static_cast<std::uint32_t>(b));
}
inline std::uint32_t max(const std::int32_t a, const std::uint32_t b)
{
    return sycl::max(static_cast<std::uint32_t>(a), b);
}
inline std::int32_t max(const std::int32_t a, const std::int32_t b)
{
    return sycl::max(a, b);
}
inline std::uint32_t max(const std::uint32_t a, const std::uint32_t b)
{
    return sycl::max(a, b);
}
inline std::uint64_t max(const std::uint64_t a, const std::int64_t b)
{
    return sycl::max(a, static_cast<std::uint64_t>(b));
}
inline std::uint64_t max(const std::int64_t a, const std::uint64_t b)
{
    return sycl::max(static_cast<std::uint64_t>(a), b);
}
inline std::int64_t max(const std::int64_t a, const std::int64_t b)
{
    return sycl::max(a, b);
}
inline std::uint64_t max(const std::uint64_t a, const std::uint64_t b)
{
    return sycl::max(a, b);
}
inline std::uint64_t max(const std::uint64_t a, const std::int32_t b)
{
    return sycl::max(a, static_cast<std::uint64_t>(b));
}
inline std::uint64_t max(const std::int32_t a, const std::uint64_t b)
{
    return sycl::max(static_cast<std::uint64_t>(a), b);
}
inline std::uint64_t max(const std::uint64_t a, const std::uint32_t b)
{
    return sycl::max(a, static_cast<std::uint64_t>(b));
}
inline std::uint64_t max(const std::uint32_t a, const std::uint64_t b)
{
    return sycl::max(static_cast<std::uint64_t>(a), b);
}
inline void
has_capability_or_fail(const sycl::device &dev,
                       const std::initializer_list<sycl::aspect> &props)
{
    for (const auto &it : props)
    {
        if (dev.has(it))
            continue;
        switch (it)
        {
        case sycl::aspect::fp64:
            throw std::runtime_error("'double' is not supported in '" +
                                     dev.get_info<sycl::info::device::name>() +
                                     "' device");
            break;
        case sycl::aspect::fp16:
            throw std::runtime_error("'half' is not supported in '" +
                                     dev.get_info<sycl::info::device::name>() +
                                     "' device");
            break;
        default:
#define __SYCL_ASPECT(ASPECT, ID) \
    case sycl::aspect::ASPECT:    \
        return #ASPECT;
#define __SYCL_ASPECT_DEPRECATED(ASPECT, ID, MESSAGE) __SYCL_ASPECT(ASPECT, ID)
#define __SYCL_ASPECT_DEPRECATED_ALIAS(ASPECT, ID, MESSAGE)
            auto getAspectNameStr = [](sycl::aspect AspectNum) -> std::string
            {
                switch (AspectNum)
                {
#include <sycl/info/aspects.def>
#include <sycl/info/aspects_deprecated.def>
                default:
                    return "unknown aspect";
                }
            };
#undef __SYCL_ASPECT_DEPRECATED_ALIAS
#undef __SYCL_ASPECT_DEPRECATED
#undef __SYCL_ASPECT
            throw std::runtime_error(
                "'" + getAspectNameStr(it) + "' is not supported in '" +
                dev.get_info<sycl::info::device::name>() + "' device");
        }
        break;
    }
}
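
// Illustrative sketch (hypothetical, not part of the original file): guard a
// code path that relies on half-precision arithmetic; this throws with a
// readable message when the device lacks the aspect.
inline void example_require_fp16(sycl::queue &q)
{
    has_capability_or_fail(q.get_device(), {sycl::aspect::fp16});
}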
static inline unsigned int get_current_device_id()
{
    return dev_mgr::instance().current_device_id();
}

static inline device_ext &get_current_device()
{
    return dev_mgr::instance().current_device();
}

static inline sycl::queue &get_in_order_queue()
{
    return dev_mgr::instance().current_device().in_order_queue();
}
static sycl::event
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr, size_t size,
            memcpy_direction direction,
            const std::vector<sycl::event> &dep_events = {})
{
    if (!size)
        return sycl::event{};
    // USM memcpy does not need the direction hint, so it is deliberately unused.
    return q.memcpy(to_ptr, from_ptr, size, dep_events);
    GGML_UNUSED(direction);
}

// Compute the number of bytes a strided copy actually touches, so that a
// temporary buffer of this size never exceeds the matrix extent.
static inline size_t get_copy_range(sycl::range<3> size, size_t slice,
                                    size_t pitch)
{
    return slice * (size.get(2) - 1) + pitch * (size.get(1) - 1) + size.get(0);
}

static inline size_t get_offset(sycl::id<3> id, size_t slice,
                                size_t pitch)
{
    return slice * id.get(2) + pitch * id.get(1) + id.get(0);
}

/// copy 3D matrix specified by \p size from 3D matrix specified by \p from_ptr
/// and \p from_range to another specified by \p to_ptr and \p to_range.
static inline std::vector<sycl::event>
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
            sycl::range<3> to_range, sycl::range<3> from_range,
            sycl::id<3> to_id, sycl::id<3> from_id,
            sycl::range<3> size, memcpy_direction direction,
            const std::vector<sycl::event> &dep_events = {})
{
    // RAII wrapper for a temporary host pointer.
    class host_buffer
    {
        void *_buf;
        size_t _size;
        sycl::queue &_q;
        const std::vector<sycl::event> &_deps; // events the deferred free depends on
    public:
        host_buffer(size_t size, sycl::queue &q,
                    const std::vector<sycl::event> &deps)
            : _buf(std::malloc(size)), _size(size), _q(q), _deps(deps) {}
        void *get_ptr() const { return _buf; }
        size_t get_size() const { return _size; }
        ~host_buffer()
        {
            if (_buf)
            {
                _q.submit([&](sycl::handler &cgh)
                {
                    cgh.depends_on(_deps);
                    cgh.host_task([buf = _buf] { std::free(buf); });
                });
            }
        }
    };
    std::vector<sycl::event> event_list;
    size_t to_slice = to_range.get(1) * to_range.get(0),
           from_slice = from_range.get(1) * from_range.get(0);
    unsigned char *to_surface =
        (unsigned char *)to_ptr + get_offset(to_id, to_slice, to_range.get(0));
    const unsigned char *from_surface =
        (const unsigned char *)from_ptr +
        get_offset(from_id, from_slice, from_range.get(0));
    if (to_slice == from_slice && to_slice == size.get(1) * size.get(0))
    {
        return {dpct_memcpy(q, to_surface, from_surface, to_slice * size.get(2),
                            direction, dep_events)};
    }
    direction = detail::deduce_memcpy_direction(q, to_ptr, from_ptr, direction);
    size_t size_slice = size.get(1) * size.get(0);
    switch (direction)
    {
    case host_to_host:
        for (size_t z = 0; z < size.get(2); ++z)
        {
            unsigned char *to_ptr = to_surface;
            const unsigned char *from_ptr = from_surface;
            if (to_range.get(0) == from_range.get(0) &&
                to_range.get(0) == size.get(0))
            {
                event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size_slice,
                                                 direction, dep_events));
            }
            else
            {
                for (size_t y = 0; y < size.get(1); ++y)
                {
                    event_list.push_back(dpct_memcpy(q, to_ptr, from_ptr, size.get(0),
                                                     direction, dep_events));
                    to_ptr += to_range.get(0);
                    from_ptr += from_range.get(0);
                }
            }
            to_surface += to_slice;
            from_surface += from_slice;
        }
        break;
    case host_to_device:
    {
        host_buffer buf(get_copy_range(size, to_slice, to_range.get(0)), q,
                        event_list);
        std::vector<sycl::event> host_events;
        if (to_slice == size_slice)
        {
            // Copy host data to a temp host buffer with the shape of target.
            host_events =
                dpct_memcpy(q, buf.get_ptr(), from_surface, to_range, from_range,
                            sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size,
                            host_to_host, dep_events);
        }
        else
        {
            // Copy host data to a temp host buffer with the shape of target.
            host_events = dpct_memcpy(
                q, buf.get_ptr(), from_surface, to_range, from_range,
                sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0), size, host_to_host,
                // The destination may contain padding bytes whose contents are
                // unknown but possibly meaningful, so pre-fill the temp buffer
                // from the device before overwriting the copied region.
                std::vector<sycl::event>{
                    dpct_memcpy(q, buf.get_ptr(), to_surface, buf.get_size(),
                                device_to_host, dep_events)});
        }
        // Copy from temp host buffer to device with only one submit.
        event_list.push_back(dpct_memcpy(q, to_surface, buf.get_ptr(),
                                         buf.get_size(), host_to_device,
                                         host_events));
        break;
    }
    case device_to_host:
    {
        host_buffer buf(get_copy_range(size, from_slice, from_range.get(0)), q,
                        event_list);
        // Copy from the temp host buffer to the host target with reshaping.
        event_list = dpct_memcpy(
            q, to_surface, buf.get_ptr(), to_range, from_range, sycl::id<3>(0, 0, 0),
            sycl::id<3>(0, 0, 0), size, host_to_host,
            // Copy from device to temp host buffer with only one submit.
            std::vector<sycl::event>{dpct_memcpy(q, buf.get_ptr(), from_surface,
                                                 buf.get_size(),
                                                 device_to_host, dep_events)});
        break;
    }
    case device_to_device:
        event_list.push_back(q.submit([&](sycl::handler &cgh)
        {
            cgh.depends_on(dep_events);
            cgh.parallel_for<class dpct_memcpy_3d_detail>(
                size,
                [=](sycl::id<3> id) {
                    to_surface[get_offset(id, to_slice, to_range.get(0))] =
                        from_surface[get_offset(id, from_slice, from_range.get(0))];
                });
        }));
        break;
    default:
        throw std::runtime_error("dpct_memcpy: invalid direction value");
    }
    return event_list;
}

/// memcpy 2D/3D matrix specified by pitched_data.
static inline std::vector<sycl::event>
dpct_memcpy(sycl::queue &q, pitched_data to, sycl::id<3> to_id,
            pitched_data from, sycl::id<3> from_id, sycl::range<3> size,
            memcpy_direction direction = automatic)
{
    return dpct_memcpy(q, to.get_data_ptr(), from.get_data_ptr(),
                       sycl::range<3>(to.get_pitch(), to.get_y(), 1),
                       sycl::range<3>(from.get_pitch(), from.get_y(), 1),
                       to_id, from_id, size, direction);
}

/// memcpy 2D matrix with pitch.
static inline std::vector<sycl::event>
dpct_memcpy(sycl::queue &q, void *to_ptr, const void *from_ptr,
            size_t to_pitch, size_t from_pitch, size_t x, size_t y,
            memcpy_direction direction = automatic)
{
    return dpct_memcpy(q, to_ptr, from_ptr, sycl::range<3>(to_pitch, y, 1),
                       sycl::range<3>(from_pitch, y, 1),
                       sycl::id<3>(0, 0, 0), sycl::id<3>(0, 0, 0),
                       sycl::range<3>(x, y, 1), direction);
}
inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans,
                 oneapi::mkl::transpose b_trans, int m, int n, int k,
                 const void *alpha, const void *a, library_data_t a_type,
                 int lda, const void *b, library_data_t b_type, int ldb,
                 const void *beta, void *c, library_data_t c_type, int ldc,
                 library_data_t scaling_type)
{
    if (scaling_type == library_data_t::real_float &&
        c_type == library_data_t::complex_float)
    {
        scaling_type = library_data_t::complex_float;
    }
    else if (scaling_type == library_data_t::real_double &&
             c_type == library_data_t::complex_double)
    {
        scaling_type = library_data_t::complex_double;
    }
    std::uint64_t key =
        detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
    switch (key)
    {
    case detail::get_type_combination_id(
        library_data_t::real_float, library_data_t::real_float,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_impl<float, float, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_double, library_data_t::real_double,
        library_data_t::real_double, library_data_t::real_double):
    {
        detail::gemm_impl<double, double, double, double>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_float, library_data_t::complex_float,
        library_data_t::complex_float, library_data_t::complex_float):
    {
        detail::gemm_impl<std::complex<float>, std::complex<float>,
                          std::complex<float>, std::complex<float>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_double, library_data_t::complex_double,
        library_data_t::complex_double, library_data_t::complex_double):
    {
        detail::gemm_impl<std::complex<double>, std::complex<double>,
                          std::complex<double>, std::complex<double>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_half):
    {
        detail::gemm_impl<sycl::half, sycl::half, sycl::half,
                          sycl::half>(q, a_trans, b_trans, m, n, k, alpha, a,
                                      lda, b, ldb, beta, c, ldc);
        break;
    }
#ifdef __INTEL_MKL__
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
                          float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
                                 ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_impl<sycl::half, sycl::half, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_float):
    {
        float alpha_value =
            dpct::get_value(reinterpret_cast<const float *>(alpha), q);
        float beta_value =
            dpct::get_value(reinterpret_cast<const float *>(beta), q);
        sycl::half alpha_half(alpha_value);
        sycl::half beta_half(beta_value);
        detail::gemm_impl<sycl::half, sycl::half, sycl::half,
                          sycl::half>(q, a_trans, b_trans, m, n, k, &alpha_half,
                                      a, lda, b, ldb, &beta_half, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_impl<std::int8_t, std::int8_t, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_bfloat16, library_data_t::real_float):
    {
        detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
                          oneapi::mkl::bfloat16, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_int32, library_data_t::real_int32):
    {
        float alpha_float =
            dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
        float beta_float =
            dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
        detail::gemm_impl<std::int8_t, std::int8_t, std::int32_t, float>(
            q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
        break;
    }
#endif // __INTEL_MKL__
    default:
        throw std::runtime_error("the combination of data types is unsupported");
    }
} // gemm()
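
// Illustrative sketch (hypothetical helper, not part of the original file): a
// single-precision GEMM through the type-dispatched entry point above.
// Computes C = alpha * A * B + beta * C for column-major 4x4 matrices.
inline void example_sgemm(sycl::queue &q, const float *a, const float *b,
                          float *c)
{
    const float alpha = 1.0f, beta = 0.0f;
    gemm(q, oneapi::mkl::transpose::nontrans, oneapi::mkl::transpose::nontrans,
         4, 4, 4, &alpha, a, library_data_t::real_float, 4, b,
         library_data_t::real_float, 4, &beta, c, library_data_t::real_float,
         4, library_data_t::real_float);
}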
/// Computes a batch of matrix-matrix product with general matrices.
/// \param [in] q The queue where the routine should be executed.
/// \param [in] a_trans Specifies the operation applied to A.
/// \param [in] b_trans Specifies the operation applied to B.
/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
/// \param [in] alpha Scaling factor for the matrix-matrix product.
/// \param [in] a Input matrix A.
/// \param [in] a_type Data type of the matrix A.
/// \param [in] lda Leading dimension of A.
/// \param [in] b Input matrix B.
/// \param [in] b_type Data type of the matrix B.
/// \param [in] ldb Leading dimension of B.
/// \param [in] beta Scaling factor for matrix C.
/// \param [in, out] c Input/Output matrix C.
/// \param [in] c_type Data type of the matrix C.
/// \param [in] ldc Leading dimension of C.
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
/// \param [in] scaling_type Data type of the scaling factors.
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
                       oneapi::mkl::transpose b_trans, int m, int n, int k,
                       const void *alpha, const void *a[],
                       library_data_t a_type, int lda, const void *b[],
                       library_data_t b_type, int ldb, const void *beta,
                       void *c[], library_data_t c_type, int ldc,
                       int batch_size, library_data_t scaling_type)
{
    if (scaling_type == library_data_t::real_float &&
        c_type == library_data_t::complex_float)
    {
        scaling_type = library_data_t::complex_float;
    }
    else if (scaling_type == library_data_t::real_double &&
             c_type == library_data_t::complex_double)
    {
        scaling_type = library_data_t::complex_double;
    }
    std::uint64_t key =
        detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
    switch (key)
    {
    case detail::get_type_combination_id(
        library_data_t::real_float, library_data_t::real_float,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<float, float, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
            batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_double, library_data_t::real_double,
        library_data_t::real_double, library_data_t::real_double):
    {
        detail::gemm_batch_impl<double, double, double, double>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
            batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_float, library_data_t::complex_float,
        library_data_t::complex_float, library_data_t::complex_float):
    {
        detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
                                std::complex<float>, std::complex<float>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
            batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_double, library_data_t::complex_double,
        library_data_t::complex_double, library_data_t::complex_double):
    {
        detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
                                std::complex<double>, std::complex<double>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
            batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_half):
    {
        detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
                                sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
                                            a, lda, b, ldb, beta, c, ldc,
                                            batch_size);
        break;
    }
#ifdef __INTEL_MKL__
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_bfloat16, library_data_t::real_float):
    {
        detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
                                oneapi::mkl::bfloat16, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
            batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
                                float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
                                       b, ldb, beta, c, ldc, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_int32, library_data_t::real_int32):
    {
        float alpha_float =
            dpct::get_value(reinterpret_cast<const std::int32_t *>(alpha), q);
        float beta_float =
            dpct::get_value(reinterpret_cast<const std::int32_t *>(beta), q);
        detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
                                float>(q, a_trans, b_trans, m, n, k, &alpha_float,
                                       a, lda, b, ldb, &beta_float, c, ldc,
                                       batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
            batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
            batch_size);
        break;
    }
#endif
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_float):
    {
        float alpha_value =
            dpct::get_value(reinterpret_cast<const float *>(alpha), q);
        float beta_value =
            dpct::get_value(reinterpret_cast<const float *>(beta), q);
        sycl::half alpha_half(alpha_value);
        sycl::half beta_half(beta_value);
        detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
            q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, b, ldb, &beta_half, c, ldc,
            batch_size);
        break;
    }
    default:
        throw std::runtime_error("the combination of data types is unsupported");
    }
}
/// Computes a batch of matrix-matrix product with general matrices.
/// \param [in] q The queue where the routine should be executed.
/// \param [in] a_trans Specifies the operation applied to A.
/// \param [in] b_trans Specifies the operation applied to B.
/// \param [in] m Specifies the number of rows of the matrix op(A) and of the matrix C.
/// \param [in] n Specifies the number of columns of the matrix op(B) and of the matrix C.
/// \param [in] k Specifies the number of columns of the matrix op(A) and the number of rows of the matrix op(B).
/// \param [in] alpha Scaling factor for the matrix-matrix product.
/// \param [in] a Input matrix A.
/// \param [in] a_type Data type of the matrix A.
/// \param [in] lda Leading dimension of A.
/// \param [in] stride_a Stride between the different A matrices.
/// \param [in] b Input matrix B.
/// \param [in] b_type Data type of the matrix B.
/// \param [in] ldb Leading dimension of B.
/// \param [in] stride_b Stride between the different B matrices.
/// \param [in] beta Scaling factor for matrix C.
/// \param [in, out] c Input/Output matrix C.
/// \param [in] c_type Data type of the matrix C.
/// \param [in] ldc Leading dimension of C.
/// \param [in] stride_c Stride between the different C matrices.
/// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
/// \param [in] scaling_type Data type of the scaling factors.
inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
                       oneapi::mkl::transpose b_trans, int m, int n, int k,
                       const void *alpha, const void *a, library_data_t a_type,
                       int lda, long long int stride_a, const void *b,
                       library_data_t b_type, int ldb, long long int stride_b,
                       const void *beta, void *c, library_data_t c_type,
                       int ldc, long long int stride_c, int batch_size,
                       library_data_t scaling_type)
{
    if (scaling_type == library_data_t::real_float &&
        c_type == library_data_t::complex_float)
    {
        scaling_type = library_data_t::complex_float;
    }
    else if (scaling_type == library_data_t::real_double &&
             c_type == library_data_t::complex_double)
    {
        scaling_type = library_data_t::complex_double;
    }
    std::uint64_t key =
        detail::get_type_combination_id(a_type, b_type, c_type, scaling_type);
    switch (key)
    {
    case detail::get_type_combination_id(
        library_data_t::real_float, library_data_t::real_float,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<float, float, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_double, library_data_t::real_double,
        library_data_t::real_double, library_data_t::real_double):
    {
        detail::gemm_batch_impl<double, double, double, double>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_float, library_data_t::complex_float,
        library_data_t::complex_float, library_data_t::complex_float):
    {
        detail::gemm_batch_impl<std::complex<float>, std::complex<float>,
                                std::complex<float>, std::complex<float>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::complex_double, library_data_t::complex_double,
        library_data_t::complex_double, library_data_t::complex_double):
    {
        detail::gemm_batch_impl<std::complex<double>, std::complex<double>,
                                std::complex<double>, std::complex<double>>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_half):
    {
        detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half,
                                sycl::half>(q, a_trans, b_trans, m, n, k, alpha,
                                            a, lda, stride_a, b, ldb, stride_b,
                                            beta, c, ldc, stride_c, batch_size);
        break;
    }
#ifdef __INTEL_MKL__
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_bfloat16, library_data_t::real_float):
    {
        detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
                                oneapi::mkl::bfloat16, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_bfloat16, library_data_t::real_bfloat16,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
                                float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
                                       stride_a, b, ldb, stride_b, beta, c, ldc,
                                       stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_int32, library_data_t::real_int32):
    {
        detail::gemm_batch_impl<std::int8_t, std::int8_t, std::int32_t,
                                std::int32_t>(q, a_trans, b_trans, m, n, k, alpha,
                                              a, lda, stride_a, b, ldb, stride_b,
                                              beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_int8, library_data_t::real_int8,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<std::int8_t, std::int8_t, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_float, library_data_t::real_float):
    {
        detail::gemm_batch_impl<sycl::half, sycl::half, float, float>(
            q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
            beta, c, ldc, stride_c, batch_size);
        break;
    }
#endif
    case detail::get_type_combination_id(
        library_data_t::real_half, library_data_t::real_half,
        library_data_t::real_half, library_data_t::real_float):
    {
        float alpha_value =
            dpct::get_value(reinterpret_cast<const float *>(alpha), q);
        float beta_value =
            dpct::get_value(reinterpret_cast<const float *>(beta), q);
        sycl::half alpha_half(alpha_value);
        sycl::half beta_half(beta_value);
        detail::gemm_batch_impl<sycl::half, sycl::half, sycl::half, sycl::half>(
            q, a_trans, b_trans, m, n, k, &alpha_half, a, lda, stride_a, b, ldb, stride_b,
            &beta_half, c, ldc, stride_c, batch_size);
        break;
    }
    default:
        throw std::runtime_error("the combination of data types is unsupported");
    }
}
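
// Illustrative sketch (hypothetical, not part of the original file): a
// strided batch of 8 half-precision 16x16 GEMMs with fp32 scaling factors,
// exercising the real_half/real_half/real_half/real_float dispatch above.
inline void example_hgemm_batch(sycl::queue &q, const sycl::half *a,
                                const sycl::half *b, sycl::half *c)
{
    const float alpha = 1.0f, beta = 0.0f;
    gemm_batch(q, oneapi::mkl::transpose::nontrans,
               oneapi::mkl::transpose::nontrans, 16, 16, 16, &alpha, a,
               library_data_t::real_half, 16, 16 * 16, b,
               library_data_t::real_half, 16, 16 * 16, &beta, c,
               library_data_t::real_half, 16, 16 * 16, 8,
               library_data_t::real_float);
}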
static inline void
async_dpct_memcpy(void *to_ptr, size_t to_pitch, const void *from_ptr,
                  size_t from_pitch, size_t x, size_t y,
                  memcpy_direction direction = automatic,
                  sycl::queue &q = get_default_queue())
{
    detail::dpct_memcpy(q, to_ptr, from_ptr, to_pitch, from_pitch, x, y,
                        direction);
}

using err0 = detail::generic_error_type<struct err0_tag, int>;
using err1 = detail::generic_error_type<struct err1_tag, int>;

static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) {
    detail::dpct_free(ptr, q);
}
/// dpct accessor used as device function parameter.
template <class T, memory_region Memory, size_t Dimension> class accessor;

template <class T, memory_region Memory> class accessor<T, Memory, 3> {
public:
    using memory_t = detail::memory_traits<Memory, T>;
    using element_t = typename memory_t::element_t;
    using pointer_t = typename memory_t::pointer_t;
    using accessor_t = typename memory_t::template accessor_t<3>;
    accessor(pointer_t data, const sycl::range<3> &in_range)
        : _data(data), _range(in_range) {}
    template <memory_region M = Memory>
    accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
        : accessor(acc, acc.get_range()) {}
    accessor(const accessor_t &acc, const sycl::range<3> &in_range)
        : accessor(acc.get_pointer(), in_range) {}
    accessor<T, Memory, 2> operator[](size_t index) const {
        sycl::range<2> sub(_range.get(1), _range.get(2));
        return accessor<T, Memory, 2>(_data + index * sub.size(), sub);
    }
    pointer_t get_ptr() const { return _data; }

private:
    pointer_t _data;
    sycl::range<3> _range;
};

template <class T, memory_region Memory> class accessor<T, Memory, 2> {
public:
    using memory_t = detail::memory_traits<Memory, T>;
    using element_t = typename memory_t::element_t;
    using pointer_t = typename memory_t::pointer_t;
    using accessor_t = typename memory_t::template accessor_t<2>;
    accessor(pointer_t data, const sycl::range<2> &in_range)
        : _data(data), _range(in_range) {}
    template <memory_region M = Memory>
    accessor(typename std::enable_if<M != local, const accessor_t>::type &acc)
        : accessor(acc, acc.get_range()) {}
    accessor(const accessor_t &acc, const sycl::range<2> &in_range)
        : accessor(acc.get_pointer(), in_range) {}
    pointer_t operator[](size_t index) const {
        return _data + _range.get(1) * index;
    }
    pointer_t get_ptr() const { return _data; }

private:
    pointer_t _data;
    sycl::range<2> _range;
};
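
// Illustrative sketch (hypothetical, not part of the original file): the 3-D
// accessor peels off one dimension per operator[], so element (z, y, x) of a
// global-memory view is reached as view[z][y][x].
inline float example_read_element(const accessor<float, global, 3> &view,
                                  size_t z, size_t y, size_t x)
{
    return view[z][y][x];
}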
  2571. namespace detail {
  2572. /// Device variable with address space of shared, global or constant.
  2573. template <class T, memory_region Memory, size_t Dimension> class device_memory {
  2574. public:
  2575. using accessor_t =
  2576. typename detail::memory_traits<Memory,
  2577. T>::template accessor_t<Dimension>;
  2578. using value_t = typename detail::memory_traits<Memory, T>::value_t;
  2579. using dpct_accessor_t = dpct::accessor<T, Memory, Dimension>;
  2580. device_memory() : device_memory(sycl::range<Dimension>(1)) {}
  2581. /// Constructor of 1-D array with initializer list
  2582. device_memory(const sycl::range<Dimension> &in_range,
  2583. std::initializer_list<value_t> &&init_list)
  2584. : device_memory(in_range) {
  2585. assert(init_list.size() <= in_range.size());
  2586. _host_ptr = (value_t *)std::malloc(_size);
  2587. std::memset(_host_ptr, 0, _size);
  2588. std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T));
  2589. }
  2590. /// Constructor of 2-D array with initializer list
  2591. template <size_t D = Dimension>
  2592. device_memory(
  2593. const typename std::enable_if<D == 2, sycl::range<2>>::type &in_range,
  2594. std::initializer_list<std::initializer_list<value_t>> &&init_list)
  2595. : device_memory(in_range) {
  2596. assert(init_list.size() <= in_range[0]);
  2597. _host_ptr = (value_t *)std::malloc(_size);
  2598. std::memset(_host_ptr, 0, _size);
  2599. auto tmp_data = _host_ptr;
  2600. for (auto sub_list : init_list) {
  2601. assert(sub_list.size() <= in_range[1]);
  2602. std::memcpy(tmp_data, sub_list.begin(),
  2603. sub_list.size() * sizeof(T));
  2604. tmp_data += in_range[1];
  2605. }
  2606. }
  2607. /// Constructor with range
  2608. device_memory(const sycl::range<Dimension> &range_in)
  2609. : _size(range_in.size() * sizeof(T)), _range(range_in),
  2610. _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) {
  2611. static_assert(
  2612. (Memory == global) || (Memory == constant) || (Memory == shared),
  2613. "device memory region should be global, constant or shared");
  2614. // Make sure that singleton class mem_mgr and dev_mgr will destruct
  2615. // later than this.
  2616. detail::mem_mgr::instance();
  2617. dev_mgr::instance();
  2618. }
  2619. /// Constructor with range
  2620. template <class... Args>
  2621. device_memory(Args... Arguments)
  2622. : device_memory(sycl::range<Dimension>(Arguments...)) {}
  ~device_memory() {
    if (_device_ptr && !_reference)
      dpct::dpct_free(_device_ptr);
    if (_host_ptr)
      std::free(_host_ptr);
  }

  /// Allocate memory with the default queue, and copy the initial value to
  /// the device if one was given.
  void init() { init(dpct::get_default_queue()); }
  /// Allocate memory with the specified queue, and copy the initial value to
  /// the device if one was given.
  void init(sycl::queue &q) {
    if (_device_ptr)
      return;
    if (!_size)
      return;
    allocate_device(q);
    if (_host_ptr)
      detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size, host_to_device);
  }
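
  // Usage sketch (editorial addition, not part of the copied DPCT code):
  // allocation is lazy, so nothing exists on the device until init() or
  // get_ptr() runs; `q` is an assumed sycl::queue.
  //
  //   detail::device_memory<float, global, 1> buf(sycl::range<1>(16));
  //   buf.init(q); // allocates 16 * sizeof(float); no host copy (no initial value)
  //   buf.init(q); // no-op: _device_ptr is already set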
  /// Rebind this variable to an existing device pointer of the given size in
  /// bytes; the pointer is not freed by the destructor.
  void assign(value_t *src, size_t size) {
    this->~device_memory();
    new (this) device_memory(src, size);
  }
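
  // Usage sketch (editorial addition, not part of the copied DPCT code):
  // after assign(), _reference is true, so the destructor will not free the
  // memory; `ext_ptr` is an assumed, externally owned device pointer.
  //
  //   buf.assign(ext_ptr, 16 * sizeof(float));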
  /// Get the memory pointer of the object, which is a virtual pointer when
  /// USM is not used, and a device pointer when USM is used.
  value_t *get_ptr() { return get_ptr(get_default_queue()); }
  /// Get the memory pointer of the object, which is a virtual pointer when
  /// USM is not used, and a device pointer when USM is used.
  value_t *get_ptr(sycl::queue &q) {
    init(q);
    return _device_ptr;
  }

  /// Get the device memory object size in bytes.
  size_t get_size() { return _size; }
  /// Index into 1-D device memory; allocates on first use.
  template <size_t D = Dimension>
  typename std::enable_if<D == 1, T>::type &operator[](size_t index) {
    init();
    return _device_ptr[index];
  }

  /// Get a dpct::accessor carrying the dimension info for the device memory
  /// object, for use when USM is used and the dimension is greater than 1.
  template <size_t D = Dimension>
  typename std::enable_if<D != 1, dpct_accessor_t>::type
  get_access([[maybe_unused]] sycl::handler &cgh) {
    return dpct_accessor_t((T *)_device_ptr, _range);
  }
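
  // Usage sketch (editorial addition, not part of the copied DPCT code):
  // indexing a 2-D device variable inside a kernel via get_access(); `q` is
  // an assumed sycl::queue and `mat` a 2-D device_memory as in the earlier
  // sketch.
  //
  //   mat.init(q); // ensure the device allocation exists
  //   q.submit([&](sycl::handler &cgh) {
  //     auto acc = mat.get_access(cgh);
  //     cgh.single_task([=] { acc[0][0] = 42; });
  //   });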
private:
  device_memory(value_t *memory_ptr, size_t size)
      : _size(size), _range(size / sizeof(T)), _reference(true),
        _host_ptr(nullptr), _device_ptr(memory_ptr) {}
  void allocate_device(sycl::queue &q) {
#ifndef DPCT_USM_LEVEL_NONE
    if (Memory == shared) {
      // USM shared allocation, accessible from both host and device.
      _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(),
                                                   q.get_context());
      return;
    }
#ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY
    if (Memory == constant) {
      // Device allocation marked read-only for the device, allowing the
      // backend to place it in constant-like memory where supported.
      _device_ptr = (value_t *)sycl::malloc_device(
          _size, q.get_device(), q.get_context(),
          sycl::ext::oneapi::property::usm::device_read_only());
      return;
    }
#endif
#endif
    // Fallback: plain device (or virtual-pointer) allocation.
    _device_ptr = (value_t *)detail::dpct_malloc(_size, q);
  }
  size_t _size;
  sycl::range<Dimension> _range;
  bool _reference;
  value_t *_host_ptr;
  value_t *_device_ptr;
};
template <class T, memory_region Memory>
class device_memory<T, Memory, 0> : public device_memory<T, Memory, 1> {
public:
  using base = device_memory<T, Memory, 1>;
  using value_t = typename base::value_t;
  using accessor_t =
      typename detail::memory_traits<Memory, T>::template accessor_t<0>;
  /// Constructor with an initial value.
  device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {}
  /// Default constructor.
  device_memory() : base(1) {}
};
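
// Usage sketch (editorial addition, not part of the copied DPCT code): the
// 0-D specialization models a single device scalar backed by a one-element
// 1-D array.
//
//   detail::device_memory<float, global, 0> alpha(1.5f);
//   float *p = alpha.get_ptr(); // one float, initialized to 1.5f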
} // namespace detail

template <class T, size_t Dimension>
using global_memory = detail::device_memory<T, global, Dimension>;
template <class T, size_t Dimension>
using constant_memory = detail::device_memory<T, constant, Dimension>;
template <class T, size_t Dimension>
using shared_memory = detail::device_memory<T, shared, Dimension>;
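
// Usage sketch (editorial addition, not part of the copied DPCT code): these
// aliases roughly mirror CUDA __device__, __constant__ and managed variables;
// they are typically declared at namespace scope and allocated lazily.
//
//   static dpct::global_memory<float, 1> g_table(sycl::range<1>(256));
//   static dpct::constant_memory<int, 2> c_lut(sycl::range<2>(4, 4));
//   static dpct::shared_memory<double, 1> s_buf(128);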
template <typename T,
          sycl::access::address_space addressSpace =
              sycl::access::address_space::global_space,
          sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
          sycl::memory_scope memoryScope = sycl::memory_scope::device>
inline T atomic_fetch_add(T *addr, T operand) {
  auto atm =
      sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
  return atm.fetch_add(operand);
}
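
// Usage sketch (editorial addition, not part of the copied DPCT code): a
// relaxed, device-scope atomic add on a global-space counter; `counter` is
// an assumed pointer into a USM device allocation.
//
//   int old = dpct::atomic_fetch_add(counter, 1); // returns the prior value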
template <sycl::access::address_space addressSpace =
              sycl::access::address_space::global_space,
          sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
          sycl::memory_scope memoryScope = sycl::memory_scope::device,
          typename T1, typename T2>
inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
  auto atm =
      sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
  return atm.fetch_add(operand);
}
template <typename T, sycl::access::address_space addressSpace =
                          sycl::access::address_space::global_space>
inline T atomic_fetch_add(T *addr, T operand,
                          sycl::memory_order memoryOrder) {
  switch (memoryOrder) {
  case sycl::memory_order::relaxed:
    return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
                            sycl::memory_scope::device>(addr, operand);
  case sycl::memory_order::acq_rel:
    return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
                            sycl::memory_scope::device>(addr, operand);
  case sycl::memory_order::seq_cst:
    return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
                            sycl::memory_scope::device>(addr, operand);
  default:
    assert(false && "Invalid memory_order for atomics. Valid memory orders "
                    "for atomics are: sycl::memory_order::relaxed, "
                    "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
    // Unreachable with a valid memory order.
  }
}
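
// Usage sketch (editorial addition, not part of the copied DPCT code): this
// overload resolves a memory order known only at run time onto the
// compile-time specializations above.
//
//   int old = dpct::atomic_fetch_add(counter, 1, sycl::memory_order::seq_cst);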
template <sycl::access::address_space addressSpace =
              sycl::access::address_space::global_space,
          typename T1, typename T2>
inline T1 atomic_fetch_add(T1 *addr, T2 operand,
                           sycl::memory_order memoryOrder) {
  return atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
}
} // COPY from DPCT header files

#endif // GGML_SYCL_DPCT_HELPER_HPP