ggml-webgpu.cpp 143 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865
  1. /*
  2. WebGPU backend implementation.
  3. Note: Use ClangFormat to format this file.
  4. */
  5. #include "ggml-webgpu.h"
  6. #include "ggml-backend-impl.h"
  7. #include "ggml-impl.h"
  8. #include "ggml-wgsl-shaders.hpp"
  9. #ifdef __EMSCRIPTEN__
  10. # include <emscripten/emscripten.h>
  11. #endif
  12. #include <webgpu/webgpu_cpp.h>
  13. #include <atomic>
  14. #include <condition_variable>
  15. #include <cstring>
  16. #include <iostream>
  17. #include <map>
  18. #include <mutex>
  19. #include <optional>
  20. #include <string>
  21. #include <vector>
  // Rounds `x` up to the next multiple of `pow2`. The mask arithmetic is only
  // correct when `pow2` is a power of two.
  #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1))

  // Integer quotient of M / N rounded up (for non-negative M and positive N).
  #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N))

  #ifndef GGML_WEBGPU_DEBUG
  // Regular builds: debug logging is a no-op expression.
  #  define WEBGPU_LOG_DEBUG(msg) ((void) 0)
  #else
  // Debug builds: `msg` is streamed to stdout, so it may itself be a chain of `<<` operands.
  #  define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl
  // Number of elements printed from the debug buffer.
  #  define WEBGPU_DEBUG_BUF_ELEMS 32
  #endif  // GGML_WEBGPU_DEBUG
  #ifndef GGML_WEBGPU_CPU_PROFILE
  // CPU profiling disabled: every profiling macro expands to nothing.
  #  define WEBGPU_CPU_PROFILE_TOTAL_START(id)
  #  define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)
  #  define WEBGPU_CPU_PROFILE_DETAIL_START(id)
  #  define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)
  #else
  // Each START macro records a time point in a local named after `id`; the matching
  // END macro (same `id`, same scope) adds the elapsed milliseconds to a map on `ctx`,
  // keyed by the stringified `id`.

  // total timing (aggregated) -> ctx->cpu_time_ms
  #  define WEBGPU_CPU_PROFILE_TOTAL_START(id) auto cpu_total_start_##id = std::chrono::high_resolution_clock::now();
  #  define WEBGPU_CPU_PROFILE_TOTAL_END(id, ctx)                                                         \
        auto   cpu_total_end_##id = std::chrono::high_resolution_clock::now();                            \
        double cpu_total_time_##id =                                                                      \
            std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \
        (ctx)->cpu_time_ms[#id] += cpu_total_time_##id;

  // fine-grained timing (not included in totals) -> ctx->cpu_detail_ms
  #  define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now();
  #  define WEBGPU_CPU_PROFILE_DETAIL_END(id, ctx)                                                          \
        auto   cpu_detail_end_##id = std::chrono::high_resolution_clock::now();                             \
        double cpu_detail_time_##id =                                                                       \
            std::chrono::duration<double, std::milli>(cpu_detail_end_##id - cpu_detail_start_##id).count(); \
        (ctx)->cpu_detail_ms[#id] += cpu_detail_time_##id;
  #endif  // GGML_WEBGPU_CPU_PROFILE
  #ifdef GGML_WEBGPU_GPU_PROFILE
  #  define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS       24
  #  define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16  // e.g. enough for two timestamps
  #endif

  /* Constants */

  // Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed.
  #define WEBGPU_MAX_WG_SIZE 288

  #define WEBGPU_MUL_MAT_WG_SIZE           256
  #define WEBGPU_NUM_PARAM_BUFS            32u
  #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u
  #define WEBGPU_WAIT_ANY_TIMEOUT_MS       0

  // Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool.
  // The expansion is parenthesized so the quotient stays intact when the macro appears inside a
  // larger expression (e.g. as the right-hand operand of another division).
  #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD (WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE)

  #define WEBGPU_PARAMS_BUF_SIZE_BYTES         128  // enough for 32 parameters
  #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS       32
  #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4
  #define WEBGPU_STORAGE_BUF_BINDING_MULT      4  // a storage buffer binding size must be a multiple of 4

  // For operations which process a row in parallel, this seems like a reasonable default
  #define WEBGPU_ROW_SPLIT_WG_SIZE 64

  // Matrix multiplication parameters

  // Register tiling parameters
  #define WEBGPU_MUL_MAT_TILE_M    8
  #define WEBGPU_MUL_MAT_TILE_N    8
  #define WEBGPU_MUL_MAT_WG_SIZE_M 8
  #define WEBGPU_MUL_MAT_WG_SIZE_N 8
  #define WEBGPU_MUL_MAT_TILE_K    32

  // Subgroup matrix parameters
  // The number of subgroups in the M dimension
  #define WEBGPU_MUL_MAT_SUBGROUP_M 2
  // The number of subgroups in the N dimension
  #define WEBGPU_MUL_MAT_SUBGROUP_N 2
  // The number of subgroup matrices each subgroup accumulates over
  #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4
  #define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2

  // Matrix-vector multiplication parameters
  #define WEBGPU_MUL_MAT_VEC_WG_SIZE 256
  // Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size
  #define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64
  #define WEBGPU_MUL_MAT_VEC_TILE_K         256

  /* End Constants */
  91. // This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations.
  92. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT
  93. // Always returns the base offset of a tensor, regardless of views.
  94. static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) {
  95. if (tensor->view_src) {
  96. return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base;
  97. }
  98. return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base;
  99. }
  100. /* Struct definitions */
  101. // Forward reference
  102. static void ggml_webgpu_create_buffer(wgpu::Device & device,
  103. wgpu::Buffer & buffer,
  104. size_t size,
  105. wgpu::BufferUsage usage,
  106. const char * label);
  107. struct webgpu_pool_bufs {
  108. wgpu::Buffer host_buf;
  109. wgpu::Buffer dev_buf;
  110. };
  111. // The futures to wait on for a single queue submission
  112. struct webgpu_submission_futures {
  113. std::vector<wgpu::FutureWaitInfo> futures;
  114. };
  115. // Holds a pool of parameter buffers for WebGPU operations
  116. struct webgpu_buf_pool {
  117. std::vector<webgpu_pool_bufs> free;
  118. std::mutex mutex;
  119. std::condition_variable cv;
  120. void init(wgpu::Device device,
  121. int num_bufs,
  122. size_t buf_size,
  123. wgpu::BufferUsage dev_buf_usage,
  124. wgpu::BufferUsage host_buf_usage) {
  125. for (int i = 0; i < num_bufs; i++) {
  126. wgpu::Buffer host_buf;
  127. wgpu::Buffer dev_buf;
  128. ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf");
  129. ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf");
  130. free.push_back({ host_buf, dev_buf });
  131. }
  132. }
  133. webgpu_pool_bufs alloc_bufs() {
  134. std::unique_lock<std::mutex> lock(mutex);
  135. cv.wait(lock, [this] { return !free.empty(); });
  136. webgpu_pool_bufs bufs = free.back();
  137. free.pop_back();
  138. return bufs;
  139. }
  140. void free_bufs(std::vector<webgpu_pool_bufs> bufs) {
  141. std::lock_guard<std::mutex> lock(mutex);
  142. free.insert(free.end(), bufs.begin(), bufs.end());
  143. cv.notify_all();
  144. }
  145. void cleanup() {
  146. std::lock_guard<std::mutex> lock(mutex);
  147. for (auto & bufs : free) {
  148. bufs.host_buf.Destroy();
  149. bufs.dev_buf.Destroy();
  150. }
  151. free.clear();
  152. }
  153. };
  154. #ifdef GGML_WEBGPU_GPU_PROFILE
  155. struct webgpu_gpu_profile_bufs {
  156. wgpu::Buffer host_buf;
  157. wgpu::Buffer dev_buf;
  158. wgpu::QuerySet query_set;
  159. };
  160. // Holds a pool of parameter buffers for WebGPU operations
  161. struct webgpu_gpu_profile_buf_pool {
  162. std::vector<webgpu_gpu_profile_bufs> free;
  163. std::mutex mutex;
  164. std::condition_variable cv;
  165. void init(wgpu::Device device,
  166. int num_bufs,
  167. size_t buf_size,
  168. wgpu::BufferUsage dev_buf_usage,
  169. wgpu::BufferUsage host_buf_usage) {
  170. for (int i = 0; i < num_bufs; i++) {
  171. wgpu::Buffer host_buf;
  172. wgpu::Buffer dev_buf;
  173. ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf");
  174. ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf");
  175. // Create a query set for 2 timestamps
  176. wgpu::QuerySetDescriptor ts_query_set_desc = {};
  177. ts_query_set_desc.type = wgpu::QueryType::Timestamp;
  178. ts_query_set_desc.count = 2;
  179. wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc);
  180. free.push_back({ host_buf, dev_buf, ts_query_set });
  181. }
  182. }
  183. webgpu_gpu_profile_bufs alloc_bufs() {
  184. std::unique_lock<std::mutex> lock(mutex);
  185. cv.wait(lock, [this] { return !free.empty(); });
  186. webgpu_gpu_profile_bufs bufs = free.back();
  187. free.pop_back();
  188. return bufs;
  189. }
  190. void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) {
  191. std::lock_guard<std::mutex> lock(mutex);
  192. free.insert(free.end(), bufs.begin(), bufs.end());
  193. cv.notify_all();
  194. }
  195. void cleanup() {
  196. std::lock_guard<std::mutex> lock(mutex);
  197. for (auto & bufs : free) {
  198. bufs.host_buf.Destroy();
  199. bufs.dev_buf.Destroy();
  200. bufs.query_set.Destroy();
  201. }
  202. free.clear();
  203. }
  204. };
  205. #endif
  206. struct webgpu_pipeline {
  207. wgpu::ComputePipeline pipeline;
  208. std::string name;
  209. };
  210. struct webgpu_command {
  211. wgpu::CommandBuffer commands;
  212. webgpu_pool_bufs params_bufs;
  213. std::optional<webgpu_pool_bufs> set_rows_error_bufs;
  214. #ifdef GGML_WEBGPU_GPU_PROFILE
  215. webgpu_gpu_profile_bufs timestamp_query_bufs;
  216. std::string pipeline_name;
  217. #endif
  218. };
  219. // All the base objects needed to run operations on a WebGPU device
  220. struct webgpu_context_struct {
  221. wgpu::Instance instance;
  222. wgpu::Adapter adapter;
  223. wgpu::Device device;
  224. wgpu::Queue queue;
  225. wgpu::Limits limits;
  226. uint32_t subgroup_size;
  227. #ifndef __EMSCRIPTEN__
  228. bool supports_subgroup_matrix = false;
  229. wgpu::SubgroupMatrixConfig subgroup_matrix_config;
  230. #endif
  231. std::recursive_mutex mutex;
  232. std::atomic_uint inflight_threads = 0;
  233. webgpu_buf_pool param_buf_pool;
  234. webgpu_buf_pool set_rows_error_buf_pool;
  235. std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index
  236. std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized
  237. std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
  238. mul_mat_vec_pipelines; // src0_type, src1_type, vectorized
  239. std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized
  240. std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
  241. std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
  242. std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
  243. std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
  244. std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
  245. std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
  246. std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
  247. std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
  248. std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split
  249. std::map<int, webgpu_pipeline> scale_pipelines; // inplace
  250. std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace
  251. std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> unary_pipelines; // unary_op, type, inplace
  252. size_t memset_bytes_per_thread;
  253. // Staging buffer for reading data from the GPU
  254. wgpu::Buffer get_tensor_staging_buf;
  255. #ifdef GGML_WEBGPU_DEBUG
  256. wgpu::Buffer debug_host_buf;
  257. wgpu::Buffer debug_dev_buf;
  258. #endif
  259. #ifdef GGML_WEBGPU_CPU_PROFILE
  260. // Profiling: labeled CPU time in ms (total)
  261. std::unordered_map<std::string, double> cpu_time_ms;
  262. // Profiling: detailed CPU time in ms
  263. std::unordered_map<std::string, double> cpu_detail_ms;
  264. #endif
  265. #ifdef GGML_WEBGPU_GPU_PROFILE
  266. // Profiling: per-shader GPU time in ms
  267. std::unordered_map<std::string, double> shader_gpu_time_ms;
  268. // Profiling: pool of timestamp query buffers (one per operation)
  269. webgpu_gpu_profile_buf_pool timestamp_query_buf_pool;
  270. #endif
  271. };
  272. typedef std::shared_ptr<webgpu_context_struct> webgpu_context;
  273. struct ggml_backend_webgpu_reg_context {
  274. webgpu_context webgpu_ctx;
  275. size_t device_count;
  276. const char * name;
  277. };
  278. struct ggml_backend_webgpu_device_context {
  279. webgpu_context webgpu_ctx;
  280. std::string device_name;
  281. std::string device_desc;
  282. };
  283. struct ggml_backend_webgpu_context {
  284. webgpu_context webgpu_ctx;
  285. std::string name;
  286. };
  287. struct ggml_backend_webgpu_buffer_context {
  288. webgpu_context webgpu_ctx;
  289. wgpu::Buffer buffer;
  290. std::string label;
  291. ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) :
  292. webgpu_ctx(std::move(ctx)),
  293. buffer(std::move(buf)),
  294. label(std::move(lbl)) {}
  295. };
  296. /* End struct definitions */
  297. /* WebGPU object initializations */
  // Expands `{{KEY}}` placeholders in a WGSL shader source.
  //
  // Keys are processed one after another in the map's iteration order. For each
  // key, every occurrence of `{{KEY}}` is replaced by its value; text inserted for
  // a key is not rescanned for that same key, but it is visible to the keys
  // processed after it. A null `src` yields an empty string.
  static std::string ggml_webgpu_process_shader_repls(const char *                               src,
                                                      const std::map<std::string, std::string> & repls) {
      if (src == nullptr) {
          return {};
      }

      std::string shader(src);
      for (auto it = repls.begin(); it != repls.end(); ++it) {
          const std::string   placeholder = "{{" + it->first + "}}";
          const std::string & value       = it->second;

          // Rebuild the text piecewise: untouched segment, value, untouched segment, ...
          std::string expanded;
          size_t      cursor = 0;
          for (size_t hit = shader.find(placeholder, cursor); hit != std::string::npos;
               hit        = shader.find(placeholder, cursor)) {
              expanded.append(shader, cursor, hit - cursor);
              expanded += value;
              cursor = hit + placeholder.length();
          }
          expanded.append(shader, cursor, std::string::npos);

          shader.swap(expanded);
      }
      return shader;
  }
  316. static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device,
  317. const char * shader_code,
  318. const char * label,
  319. const std::vector<wgpu::ConstantEntry> & constants = {}) {
  320. wgpu::ShaderSourceWGSL shader_source;
  321. shader_source.code = shader_code;
  322. wgpu::ShaderModuleDescriptor shader_desc;
  323. shader_desc.nextInChain = &shader_source;
  324. wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc);
  325. wgpu::ComputePipelineDescriptor pipeline_desc;
  326. pipeline_desc.label = label;
  327. pipeline_desc.compute.module = shader_module;
  328. pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code
  329. pipeline_desc.layout = nullptr; // nullptr means auto layout
  330. if (constants.size() > 0) {
  331. pipeline_desc.compute.constants = constants.data();
  332. pipeline_desc.compute.constantCount = constants.size();
  333. }
  334. return { device.CreateComputePipeline(&pipeline_desc), label };
  335. }
  336. static void ggml_webgpu_create_buffer(wgpu::Device & device,
  337. wgpu::Buffer & buffer,
  338. size_t size,
  339. wgpu::BufferUsage usage,
  340. const char * label) {
  341. wgpu::BufferDescriptor buffer_desc;
  342. buffer_desc.size = size;
  343. buffer_desc.usage = usage;
  344. buffer_desc.label = label;
  345. buffer_desc.mappedAtCreation = false;
  346. // TODO: error handling
  347. buffer = device.CreateBuffer(&buffer_desc);
  348. }
  349. /** End WebGPU object initializations */
  350. /** WebGPU Actions */
  351. // Wait for the queue to finish processing all submitted work
  352. static void ggml_backend_webgpu_wait(webgpu_context & ctx,
  353. std::vector<webgpu_submission_futures> & futures,
  354. bool block = true) {
  355. // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads,
  356. // inflight_max may be 0, meaning that we must wait on all futures.
  357. uint64_t timeout_ms = block ? UINT64_MAX : 0;
  358. uint32_t inflight_threads = ctx->inflight_threads;
  359. uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u);
  360. while (futures.size() >= inflight_max && futures.size() > 0) {
  361. ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
  362. futures.erase(futures.begin());
  363. }
  364. size_t i = 0;
  365. while (i < futures.size()) {
  366. auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
  367. switch (waitStatus) {
  368. case wgpu::WaitStatus::Success:
  369. futures.erase(futures.begin() + i);
  370. break;
  371. case wgpu::WaitStatus::TimedOut:
  372. i++;
  373. break;
  374. case wgpu::WaitStatus::Error:
  375. GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
  376. break;
  377. default:
  378. GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
  379. break;
  380. }
  381. }
  382. }
  383. static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx,
  384. wgpu::Buffer & buffer,
  385. wgpu::MapMode mode,
  386. size_t offset,
  387. size_t size) {
  388. ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous,
  389. [](wgpu::MapAsyncStatus status, wgpu::StringView message) {
  390. if (status != wgpu::MapAsyncStatus::Success) {
  391. GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n",
  392. message.data);
  393. }
  394. }),
  395. UINT64_MAX);
  396. }
  397. #ifdef GGML_WEBGPU_DEBUG
  398. // This function adds debugging information to shaders, as WebGPU does not support printing directly.
  399. // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and
  400. // debug statements in the shader, and then call this function after encoding the commands and submitting them.
  401. static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
  402. wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
  403. encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
  404. wgpu::CommandBuffer commands = encoder.Finish();
  405. ctx->queue.Submit(1, &commands);
  406. ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
  407. const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
  408. std::cout << "debug data:";
  409. for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
  410. std::cout << " " << i << ": " << debug_data[i];
  411. }
  412. std::cout << "\n";
  413. ctx->debug_host_buf.Unmap();
  414. }
  415. #endif
  416. static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) {
  417. std::vector<wgpu::CommandBuffer> command_buffers;
  418. std::vector<webgpu_pool_bufs> params_bufs;
  419. std::vector<webgpu_pool_bufs> set_rows_error_bufs;
  420. #ifdef GGML_WEBGPU_GPU_PROFILE
  421. std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs;
  422. #endif
  423. for (const auto & command : commands) {
  424. command_buffers.push_back(command.commands);
  425. params_bufs.push_back(command.params_bufs);
  426. if (command.set_rows_error_bufs) {
  427. set_rows_error_bufs.push_back(command.set_rows_error_bufs.value());
  428. }
  429. }
  430. ctx->queue.Submit(command_buffers.size(), command_buffers.data());
  431. std::vector<wgpu::FutureWaitInfo> futures;
  432. wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone(
  433. wgpu::CallbackMode::AllowSpontaneous,
  434. [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
  435. if (status != wgpu::QueueWorkDoneStatus::Success) {
  436. GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str());
  437. }
  438. // Free the staged buffers
  439. ctx->param_buf_pool.free_bufs({ params_bufs });
  440. });
  441. futures.push_back({ p_f });
  442. for (const auto & bufs : set_rows_error_bufs) {
  443. wgpu::Future f = bufs.host_buf.MapAsync(
  444. wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
  445. [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) {
  446. if (status != wgpu::MapAsyncStatus::Success) {
  447. GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str());
  448. } else {
  449. const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange();
  450. if (*error_data) {
  451. GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported.");
  452. }
  453. // We can't unmap in here due to WebGPU reentrancy limitations.
  454. ctx->set_rows_error_buf_pool.free_bufs({ bufs });
  455. }
  456. });
  457. futures.push_back({ f });
  458. }
  459. #ifdef GGML_WEBGPU_GPU_PROFILE
  460. for (const auto & command : commands) {
  461. auto label = command.pipeline_name;
  462. auto ts_bufs = command.timestamp_query_bufs;
  463. wgpu::Future f = ts_bufs.host_buf.MapAsync(
  464. wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous,
  465. [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) {
  466. if (status != wgpu::MapAsyncStatus::Success) {
  467. GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str());
  468. } else {
  469. const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange();
  470. // WebGPU timestamps are in ns; convert to ms
  471. double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6;
  472. ctx->shader_gpu_time_ms[label] += elapsed_ms;
  473. // We can't unmap in here due to WebGPU reentrancy limitations.
  474. ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs });
  475. }
  476. });
  477. futures.push_back({ f });
  478. }
  479. #endif
  480. return { futures };
  481. }
  482. static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx,
  483. webgpu_pipeline & pipeline,
  484. std::vector<uint32_t> params,
  485. std::vector<wgpu::BindGroupEntry> bind_group_entries,
  486. uint32_t wg_x,
  487. uint32_t wg_y = 1,
  488. std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) {
  489. webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs();
  490. ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize());
  491. uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange();
  492. for (size_t i = 0; i < params.size(); i++) {
  493. _params[i] = params[i];
  494. };
  495. params_bufs.host_buf.Unmap();
  496. uint32_t params_bufs_binding_num = bind_group_entries.size();
  497. bind_group_entries.push_back({ .binding = params_bufs_binding_num,
  498. .buffer = params_bufs.dev_buf,
  499. .offset = 0,
  500. .size = params_bufs.dev_buf.GetSize() });
  501. wgpu::BindGroupDescriptor bind_group_desc;
  502. bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0);
  503. bind_group_desc.entryCount = bind_group_entries.size();
  504. bind_group_desc.entries = bind_group_entries.data();
  505. bind_group_desc.label = pipeline.name.c_str();
  506. wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc);
  507. wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder();
  508. encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize());
  509. #ifdef GGML_WEBGPU_GPU_PROFILE
  510. // --- Profiling: GPU timestamp queries ---
  511. // Allocate a timestamp query buffer (2 timestamps: start/end)
  512. webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
  513. if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
  514. ts_bufs.host_buf.Unmap();
  515. }
  516. wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set,
  517. .beginningOfPassWriteIndex = 0,
  518. .endOfPassWriteIndex = 1 };
  519. wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes };
  520. wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc);
  521. #else
  522. wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
  523. #endif
  524. pass.SetPipeline(pipeline.pipeline);
  525. pass.SetBindGroup(0, bind_group);
  526. pass.DispatchWorkgroups(wg_x, wg_y, 1);
  527. pass.End();
  528. #ifdef GGML_WEBGPU_GPU_PROFILE
  529. // Resolve the query set into the device buffer
  530. encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0);
  531. encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize());
  532. #endif
  533. // If there are SET_ROWS operations in this submission, copy their error buffers to the host.
  534. if (set_rows_error_bufs) {
  535. encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0,
  536. set_rows_error_bufs->host_buf.GetSize());
  537. }
  538. wgpu::CommandBuffer commands = encoder.Finish();
  539. webgpu_command result = {};
  540. result.commands = commands;
  541. result.params_bufs = params_bufs;
  542. result.set_rows_error_bufs = set_rows_error_bufs;
  543. #ifdef GGML_WEBGPU_GPU_PROFILE
  544. result.timestamp_query_bufs = ts_bufs;
  545. result.pipeline_name = pipeline.name;
  546. #endif
  547. return result;
  548. }
  549. static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx,
  550. wgpu::Buffer & buf,
  551. uint32_t value,
  552. size_t offset,
  553. size_t size) {
  554. std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value };
  555. std::vector<wgpu::BindGroupEntry> entries = {
  556. { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() }
  557. };
  558. size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread;
  559. uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg);
  560. webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x);
  561. std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command }) };
  562. ggml_backend_webgpu_wait(ctx, futures);
  563. }
  564. /** End WebGPU Actions */
  565. /** GGML Backend Interface */
  566. static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
  567. ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
  568. return ctx->name.c_str();
  569. }
  570. static void ggml_backend_webgpu_free(ggml_backend_t backend) {
  571. ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
  572. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
  573. #ifdef GGML_WEBGPU_CPU_PROFILE
  574. std::cout << "\n[ggml_webgpu cpu profiling summary]\n";
  575. double total_cpu = 0.0;
  576. for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
  577. total_cpu += kv.second;
  578. }
  579. std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n";
  580. std::cout << "ggml_webgpu: cpu breakdown:\n";
  581. for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) {
  582. double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
  583. std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
  584. }
  585. if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) {
  586. std::cout << "ggml_webgpu: cpu detailed breakdown:\n";
  587. }
  588. for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) {
  589. double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0;
  590. std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
  591. }
  592. #endif
  593. #ifdef GGML_WEBGPU_GPU_PROFILE
  594. std::cout << "\n[ggml_webgpu gpu profiling summary]\n";
  595. double total_gpu = 0.0;
  596. for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
  597. total_gpu += kv.second;
  598. }
  599. std::cout << "ggml_webgpu: total gpu time (all shaders): " << total_gpu << " ms\n";
  600. std::cout << "\nggml_webgpu: gpu breakdown:\n";
  601. for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) {
  602. double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0;
  603. std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n";
  604. }
  605. #endif
  606. #if defined(GGML_WEBGPU_CPU_PROFILE) && defined(GGML_WEBGPU_GPU_PROFILE)
  607. std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
  608. #endif
  609. #if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE)
  610. GGML_UNUSED(ctx);
  611. #endif
  612. }
  613. static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
  614. return webgpu_tensor_offset(tensor) + tensor->view_offs;
  615. }
  616. static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
  617. ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context;
  618. return ctx->buffer;
  619. }
  620. static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
  621. size_t offset = ggml_webgpu_tensor_offset(t);
  622. return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
  623. }
  624. static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
  625. size_t offset = ggml_webgpu_tensor_offset(t);
  626. return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
  627. }
  628. static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) {
  629. return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT);
  630. }
  631. // Used to determine if two tensors are the same for in-place operations
  632. static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
  633. return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
  634. (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
  635. }
  636. static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
  637. uint32_t ne = (uint32_t) ggml_nelements(dst);
  638. std::vector<uint32_t> params = {
  639. ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
  640. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  641. // Convert byte-strides to element-strides
  642. (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
  643. (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
  644. (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
  645. (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
  646. // Logical shapes
  647. (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
  648. (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
  649. };
  650. std::vector<wgpu::BindGroupEntry> entries = {
  651. { .binding = 0,
  652. .buffer = ggml_webgpu_tensor_buf(src),
  653. .offset = ggml_webgpu_tensor_align_offset(ctx, src),
  654. .size = ggml_webgpu_tensor_binding_size(ctx, src) },
  655. { .binding = 1,
  656. .buffer = ggml_webgpu_tensor_buf(dst),
  657. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  658. .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
  659. };
  660. uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
  661. return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x);
  662. }
  663. static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
  664. ggml_tensor * src,
  665. ggml_tensor * idx,
  666. ggml_tensor * dst) {
  667. // For set rows specifically, we need to check if src and idx are empty tensors.
  668. if (ggml_is_empty(src) || ggml_is_empty(idx)) {
  669. return std::nullopt;
  670. }
  671. webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs();
  672. if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
  673. error_bufs.host_buf.Unmap();
  674. }
  675. std::vector<uint32_t> params = {
  676. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
  677. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
  678. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  679. // Convert byte-strides to element-strides
  680. (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
  681. (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
  682. (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
  683. (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
  684. (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
  685. // Shape of src
  686. (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3],
  687. // Shape of idx
  688. (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
  689. };
  690. std::vector<wgpu::BindGroupEntry> entries = {
  691. { .binding = 0,
  692. .buffer = ggml_webgpu_tensor_buf(src),
  693. .offset = ggml_webgpu_tensor_align_offset(ctx, src),
  694. .size = ggml_webgpu_tensor_binding_size(ctx, src) },
  695. { .binding = 1,
  696. .buffer = ggml_webgpu_tensor_buf(idx),
  697. .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
  698. .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
  699. { .binding = 2,
  700. .buffer = ggml_webgpu_tensor_buf(dst),
  701. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  702. .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
  703. { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() }
  704. };
  705. int vectorized = src->ne[0] % 4 == 0;
  706. webgpu_pipeline pipeline = ctx->set_rows_pipelines[0][vectorized];
  707. uint32_t threads;
  708. if (vectorized) {
  709. threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4);
  710. } else {
  711. threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
  712. }
  713. uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE);
  714. return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs);
  715. }
  716. static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx,
  717. ggml_tensor * src,
  718. ggml_tensor * idx,
  719. ggml_tensor * dst) {
  720. std::vector<uint32_t> params = {
  721. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
  722. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)),
  723. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  724. // Convert byte-strides to element-strides
  725. (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
  726. (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)),
  727. (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)),
  728. (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
  729. (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
  730. // Shape of dst
  731. (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3],
  732. // Shape of idx
  733. (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2])
  734. };
  735. std::vector<wgpu::BindGroupEntry> entries = {
  736. { .binding = 0,
  737. .buffer = ggml_webgpu_tensor_buf(src),
  738. .offset = ggml_webgpu_tensor_align_offset(ctx, src),
  739. .size = ggml_webgpu_tensor_binding_size(ctx, src) },
  740. { .binding = 1,
  741. .buffer = ggml_webgpu_tensor_buf(idx),
  742. .offset = ggml_webgpu_tensor_align_offset(ctx, idx),
  743. .size = ggml_webgpu_tensor_binding_size(ctx, idx) },
  744. { .binding = 2,
  745. .buffer = ggml_webgpu_tensor_buf(dst),
  746. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  747. .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
  748. };
  749. uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE);
  750. uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0;
  751. webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized];
  752. return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
  753. }
  754. static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
  755. ggml_tensor * src0,
  756. ggml_tensor * src1,
  757. ggml_tensor * dst) {
  758. std::vector<uint32_t> params = {
  759. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
  760. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
  761. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  762. (uint32_t) dst->ne[0], // number of rows in result (M, transposed)
  763. (uint32_t) dst->ne[1], // number of columns in result (N)
  764. (uint32_t) src0->ne[0], // number of columns in src0/src1 (K)
  765. (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1
  766. (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1
  767. (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2
  768. (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2
  769. (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3
  770. (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3
  771. (uint32_t) src0->ne[2], // batch size in dimension 2
  772. (uint32_t) src0->ne[3], // batch size in dimension 3
  773. (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2
  774. (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3
  775. };
  776. std::vector<wgpu::BindGroupEntry> entries = {
  777. { .binding = 0,
  778. .buffer = ggml_webgpu_tensor_buf(src0),
  779. .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
  780. .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
  781. { .binding = 1,
  782. .buffer = ggml_webgpu_tensor_buf(src1),
  783. .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
  784. .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
  785. { .binding = 2,
  786. .buffer = ggml_webgpu_tensor_buf(dst),
  787. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  788. .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
  789. };
  790. webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0];
  791. uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE);
  792. uint32_t wg_y = 1;
  793. bool use_fast = false;
  794. switch (src1->type) {
  795. case GGML_TYPE_F16:
  796. use_fast = (src0->type == GGML_TYPE_F16);
  797. break;
  798. case GGML_TYPE_F32:
  799. switch (src0->type) {
  800. case GGML_TYPE_F32:
  801. case GGML_TYPE_F16:
  802. case GGML_TYPE_Q4_0:
  803. use_fast = true;
  804. break;
  805. default:
  806. break;
  807. }
  808. break;
  809. default:
  810. break;
  811. }
  812. if (use_fast) {
  813. int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0;
  814. if (dst->ne[1] == 1) {
  815. // We don't support vectorized mul_mat_vec for quantized types
  816. vectorized = vectorized && (src0->type < 2);
  817. pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized];
  818. uint32_t batches = dst->ne[2] * dst->ne[3];
  819. uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG);
  820. uint32_t total_wg = output_groups * batches;
  821. wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension;
  822. wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension);
  823. } else {
  824. pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized];
  825. uint32_t wg_m;
  826. uint32_t wg_n;
  827. #ifndef __EMSCRIPTEN__
  828. if (ctx->supports_subgroup_matrix) {
  829. // The total number of subgroups/workgroups needed per matrix.
  830. uint32_t wg_m_sg_tile =
  831. WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
  832. wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
  833. uint32_t wg_n_sg_tile =
  834. WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
  835. wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
  836. } else {
  837. #endif
  838. uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
  839. uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N;
  840. wg_m = CEIL_DIV(dst->ne[0], tile_m_s);
  841. wg_n = CEIL_DIV(dst->ne[1], tile_n_s);
  842. #ifndef __EMSCRIPTEN__
  843. }
  844. #endif
  845. wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3];
  846. }
  847. }
  848. return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
  849. }
  850. static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
  851. uint32_t ne = (uint32_t) ggml_nelements(dst);
  852. ggml_unary_op unary_op = ggml_get_unary_op(dst);
  853. uint32_t inplace = ggml_webgpu_tensor_equal(src, dst);
  854. std::vector<uint32_t> params = {
  855. ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
  856. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  857. // Convert byte-strides to element-strides
  858. (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
  859. (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
  860. (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
  861. (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
  862. // Logical shapes
  863. (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0],
  864. (uint32_t) dst->ne[1], (uint32_t) dst->ne[2]
  865. };
  866. switch (unary_op) {
  867. case GGML_UNARY_OP_XIELU:
  868. {
  869. // Get float parameters and reinterpret their bit patterns as uint32_t
  870. // for passing through the params buffer
  871. float alpha_n = ggml_get_op_params_f32(dst, 1);
  872. float alpha_p = ggml_get_op_params_f32(dst, 2);
  873. float beta = ggml_get_op_params_f32(dst, 3);
  874. float eps = ggml_get_op_params_f32(dst, 4);
  875. params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n));
  876. params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p));
  877. params.push_back(*reinterpret_cast<const uint32_t *>(&beta));
  878. params.push_back(*reinterpret_cast<const uint32_t *>(&eps));
  879. break;
  880. }
  881. default:
  882. break;
  883. }
  884. std::vector<wgpu::BindGroupEntry> entries = {
  885. { .binding = 0,
  886. .buffer = ggml_webgpu_tensor_buf(src),
  887. .offset = ggml_webgpu_tensor_align_offset(ctx, src),
  888. .size = ggml_webgpu_tensor_binding_size(ctx, src) },
  889. };
  890. if (!inplace) {
  891. entries.push_back({ .binding = 1,
  892. .buffer = ggml_webgpu_tensor_buf(dst),
  893. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  894. .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
  895. }
  896. uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE);
  897. return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x);
  898. }
  899. static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
  900. ggml_tensor * src0,
  901. ggml_tensor * src1,
  902. ggml_tensor * dst,
  903. webgpu_pipeline & pipeline,
  904. bool inplace) {
  905. std::vector<uint32_t> params = {
  906. (uint32_t) ggml_nelements(dst),
  907. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
  908. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
  909. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  910. (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
  911. (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
  912. (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
  913. (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
  914. (uint32_t) src0->ne[0],
  915. (uint32_t) src0->ne[1],
  916. (uint32_t) src0->ne[2],
  917. (uint32_t) src1->ne[0],
  918. (uint32_t) src1->ne[1],
  919. (uint32_t) src1->ne[2],
  920. (uint32_t) src1->ne[3],
  921. };
  922. std::vector<wgpu::BindGroupEntry> entries = {
  923. { .binding = 0,
  924. .buffer = ggml_webgpu_tensor_buf(src0),
  925. .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
  926. .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
  927. { .binding = 1,
  928. .buffer = ggml_webgpu_tensor_buf(src1),
  929. .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
  930. .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
  931. };
  932. if (!inplace) {
  933. entries.push_back({ .binding = 2,
  934. .buffer = ggml_webgpu_tensor_buf(dst),
  935. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  936. .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
  937. }
  938. uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
  939. return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
  940. }
  941. static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
  942. int inplace = ggml_webgpu_tensor_equal(src, dst);
  943. std::vector<uint32_t> params = {
  944. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
  945. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  946. (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
  947. (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
  948. (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
  949. (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
  950. (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
  951. (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
  952. (uint32_t) src->ne[0],
  953. (uint32_t) src->ne[1],
  954. (uint32_t) src->ne[2],
  955. (uint32_t) src->ne[3],
  956. *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader
  957. };
  958. std::vector<wgpu::BindGroupEntry> entries = {
  959. { .binding = 0,
  960. .buffer = ggml_webgpu_tensor_buf(src),
  961. .offset = ggml_webgpu_tensor_align_offset(ctx, src),
  962. .size = ggml_webgpu_tensor_binding_size(ctx, src) }
  963. };
  964. if (!inplace) {
  965. entries.push_back({ .binding = 1,
  966. .buffer = ggml_webgpu_tensor_buf(dst),
  967. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  968. .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
  969. }
  970. return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src));
  971. }
  972. static webgpu_command ggml_webgpu_rope(webgpu_context & ctx,
  973. ggml_tensor * src0,
  974. ggml_tensor * src1,
  975. ggml_tensor * src2,
  976. ggml_tensor * dst) {
  977. const int inplace = ggml_webgpu_tensor_equal(src0, dst);
  978. const int has_freq_factor = (src2 != nullptr);
  979. const int n_dims = ((int32_t *) dst->op_params)[1];
  980. const int mode = ((int32_t *) dst->op_params)[2];
  981. const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
  982. float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
  983. memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
  984. memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
  985. memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
  986. memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
  987. memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
  988. memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
  989. int sections[4];
  990. memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int));
  991. float theta_scale = powf(freq_base, -2.0f / n_dims);
  992. float corr_dims[2];
  993. ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
  994. std::vector<uint32_t> params = {
  995. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
  996. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
  997. src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
  998. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  999. (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
  1000. (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
  1001. (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
  1002. (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
  1003. (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
  1004. (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
  1005. (uint32_t) ggml_nelements(src0) / 2,
  1006. (uint32_t) src0->ne[0],
  1007. (uint32_t) src0->ne[1],
  1008. (uint32_t) src0->ne[2],
  1009. (uint32_t) n_dims,
  1010. (uint32_t) mode,
  1011. *(uint32_t *) &theta_scale,
  1012. *(uint32_t *) &attn_factor,
  1013. *(uint32_t *) &freq_scale,
  1014. *(uint32_t *) &ext_factor,
  1015. *(uint32_t *) &corr_dims[0],
  1016. *(uint32_t *) &corr_dims[1],
  1017. (uint32_t) sections[0],
  1018. (uint32_t) sections[1],
  1019. (uint32_t) sections[2],
  1020. (uint32_t) sections[3]
  1021. };
  1022. std::vector<wgpu::BindGroupEntry> entries = {
  1023. { .binding = 0,
  1024. .buffer = ggml_webgpu_tensor_buf(src0),
  1025. .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
  1026. .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
  1027. { .binding = 1,
  1028. .buffer = ggml_webgpu_tensor_buf(src1),
  1029. .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
  1030. .size = ggml_webgpu_tensor_binding_size(ctx, src1) }
  1031. };
  1032. uint32_t dst_binding = 2;
  1033. if (has_freq_factor) {
  1034. dst_binding = 3;
  1035. entries.push_back({ .binding = 2,
  1036. .buffer = ggml_webgpu_tensor_buf(src2),
  1037. .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
  1038. .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
  1039. }
  1040. if (!inplace) {
  1041. entries.push_back({ .binding = dst_binding,
  1042. .buffer = ggml_webgpu_tensor_buf(dst),
  1043. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  1044. .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
  1045. }
  1046. webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace];
  1047. uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
  1048. return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
  1049. }
  1050. static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
  1051. const int split = (src1 != nullptr);
  1052. std::vector<uint32_t> params = {
  1053. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
  1054. src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
  1055. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  1056. (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
  1057. (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
  1058. (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
  1059. src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) :
  1060. (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
  1061. src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) :
  1062. (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
  1063. src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) :
  1064. (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
  1065. (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
  1066. (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
  1067. (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
  1068. (uint32_t) ggml_nelements(dst),
  1069. (uint32_t) dst->ne[0],
  1070. (uint32_t) dst->ne[1],
  1071. (uint32_t) dst->ne[2],
  1072. (uint32_t) ((int32_t *) dst->op_params)[1], // swapped
  1073. *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai
  1074. *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai
  1075. };
  1076. std::vector<wgpu::BindGroupEntry> entries = {
  1077. { .binding = 0,
  1078. .buffer = ggml_webgpu_tensor_buf(src0),
  1079. .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
  1080. .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
  1081. };
  1082. uint32_t dst_binding = 1;
  1083. if (split) {
  1084. dst_binding = 2;
  1085. entries.push_back({ .binding = 1,
  1086. .buffer = ggml_webgpu_tensor_buf(src1),
  1087. .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
  1088. .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
  1089. }
  1090. entries.push_back({ .binding = dst_binding,
  1091. .buffer = ggml_webgpu_tensor_buf(dst),
  1092. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  1093. .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
  1094. webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split];
  1095. uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
  1096. return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
  1097. }
  1098. static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
  1099. int inplace = ggml_webgpu_tensor_equal(src, dst);
  1100. std::vector<uint32_t> params = {
  1101. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)),
  1102. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  1103. (uint32_t) (src->nb[1] / ggml_type_size(src->type)),
  1104. (uint32_t) (src->nb[2] / ggml_type_size(src->type)),
  1105. (uint32_t) (src->nb[3] / ggml_type_size(src->type)),
  1106. (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
  1107. (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
  1108. (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
  1109. (uint32_t) ggml_nelements(dst),
  1110. (uint32_t) src->ne[0],
  1111. (uint32_t) src->ne[1],
  1112. (uint32_t) src->ne[2],
  1113. *(uint32_t *) dst->op_params, // scale
  1114. *(uint32_t *) &dst->op_params[1] // bias
  1115. };
  1116. std::vector<wgpu::BindGroupEntry> entries = {
  1117. { .binding = 0,
  1118. .buffer = ggml_webgpu_tensor_buf(src),
  1119. .offset = ggml_webgpu_tensor_align_offset(ctx, src),
  1120. .size = ggml_webgpu_tensor_binding_size(ctx, src) }
  1121. };
  1122. if (!inplace) {
  1123. entries.push_back({ .binding = 1,
  1124. .buffer = ggml_webgpu_tensor_buf(dst),
  1125. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  1126. .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
  1127. }
  1128. uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
  1129. return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x);
  1130. }
  1131. static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx,
  1132. ggml_tensor * src0,
  1133. ggml_tensor * src1,
  1134. ggml_tensor * src2,
  1135. ggml_tensor * dst) {
  1136. const int inplace = ggml_webgpu_tensor_equal(src0, dst);
  1137. const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here
  1138. const int has_sink = (src2 != nullptr);
  1139. float max_bias;
  1140. memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
  1141. float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2])));
  1142. float m0 = powf(2.0f, -(max_bias) / n_head_log2);
  1143. float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
  1144. std::vector<uint32_t> params = {
  1145. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
  1146. mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0,
  1147. has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0,
  1148. (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
  1149. (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
  1150. (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
  1151. (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
  1152. mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0,
  1153. mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0,
  1154. mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0,
  1155. (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
  1156. (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
  1157. (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
  1158. (uint32_t) ggml_nelements(dst),
  1159. (uint32_t) src0->ne[0],
  1160. (uint32_t) src0->ne[1],
  1161. (uint32_t) src0->ne[2],
  1162. mask_type < 2 ? (uint32_t) src1->ne[2] : 0,
  1163. mask_type < 2 ? (uint32_t) src1->ne[3] : 0,
  1164. *(uint32_t *) dst->op_params, // scale
  1165. *(uint32_t *) &max_bias,
  1166. *(uint32_t *) &n_head_log2,
  1167. *(uint32_t *) &m0,
  1168. *(uint32_t *) &m1
  1169. };
  1170. std::vector<wgpu::BindGroupEntry> entries = {
  1171. { .binding = 0,
  1172. .buffer = ggml_webgpu_tensor_buf(src0),
  1173. .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
  1174. .size = ggml_webgpu_tensor_binding_size(ctx, src0) }
  1175. };
  1176. uint32_t binding_num = 1;
  1177. if (mask_type < 2) {
  1178. entries.push_back({ .binding = binding_num,
  1179. .buffer = ggml_webgpu_tensor_buf(src1),
  1180. .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
  1181. .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
  1182. binding_num++;
  1183. }
  1184. if (has_sink) {
  1185. entries.push_back({ .binding = binding_num,
  1186. .buffer = ggml_webgpu_tensor_buf(src2),
  1187. .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
  1188. .size = ggml_webgpu_tensor_binding_size(ctx, src2) });
  1189. binding_num++;
  1190. }
  1191. if (!inplace) {
  1192. entries.push_back({ .binding = binding_num,
  1193. .buffer = ggml_webgpu_tensor_buf(dst),
  1194. .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
  1195. .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
  1196. }
  1197. return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries,
  1198. ggml_nrows(dst));
  1199. }
  1200. // Returns the encoded command, or std::nullopt if the operation is a no-op
  1201. static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
  1202. if (ggml_is_empty(node)) {
  1203. return std::nullopt;
  1204. }
  1205. WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
  1206. ggml_tensor * src0 = node->src[0];
  1207. ggml_tensor * src1 = node->src[1];
  1208. ggml_tensor * src2 = node->src[2];
  1209. switch (node->op) {
  1210. // no-ops
  1211. case GGML_OP_NONE:
  1212. case GGML_OP_VIEW:
  1213. case GGML_OP_PERMUTE:
  1214. case GGML_OP_TRANSPOSE:
  1215. case GGML_OP_RESHAPE:
  1216. return std::nullopt;
  1217. case GGML_OP_CPY:
  1218. case GGML_OP_CONT:
  1219. return ggml_webgpu_cpy(ctx, src0, node);
  1220. case GGML_OP_SET_ROWS:
  1221. return ggml_webgpu_set_rows(ctx, src0, src1, node);
  1222. case GGML_OP_GET_ROWS:
  1223. return ggml_webgpu_get_rows(ctx, src0, src1, node);
  1224. case GGML_OP_MUL_MAT:
  1225. return ggml_webgpu_mul_mat(ctx, src0, src1, node);
  1226. case GGML_OP_ADD:
  1227. {
  1228. int inplace = ggml_webgpu_tensor_equal(src0, node);
  1229. return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
  1230. }
  1231. case GGML_OP_SUB:
  1232. {
  1233. int inplace = ggml_webgpu_tensor_equal(src0, node);
  1234. return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
  1235. }
  1236. case GGML_OP_MUL:
  1237. {
  1238. int inplace = ggml_webgpu_tensor_equal(src0, node);
  1239. return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
  1240. }
  1241. case GGML_OP_DIV:
  1242. {
  1243. int inplace = ggml_webgpu_tensor_equal(src0, node);
  1244. return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
  1245. }
  1246. case GGML_OP_RMS_NORM:
  1247. return ggml_webgpu_rms_norm(ctx, src0, node);
  1248. case GGML_OP_ROPE:
  1249. return ggml_webgpu_rope(ctx, src0, src1, src2, node);
  1250. case GGML_OP_GLU:
  1251. return ggml_webgpu_glu(ctx, src0, src1, node);
  1252. case GGML_OP_SCALE:
  1253. return ggml_webgpu_scale(ctx, src0, node);
  1254. case GGML_OP_SOFT_MAX:
  1255. return ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
  1256. case GGML_OP_UNARY:
  1257. return ggml_webgpu_unary_op(ctx, src0, node);
  1258. default:
  1259. return std::nullopt;
  1260. }
  1261. }
  1262. static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
  1263. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)");
  1264. ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context);
  1265. webgpu_context ctx = backend_ctx->webgpu_ctx;
  1266. WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
  1267. ctx->inflight_threads++;
  1268. std::vector<webgpu_command> commands;
  1269. std::vector<webgpu_submission_futures> futures;
  1270. for (int i = 0; i < cgraph->n_nodes; i++) {
  1271. if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
  1272. commands.push_back(*cmd);
  1273. }
  1274. // compute the batch size based on the number of inflight threads
  1275. uint32_t inflight_threads = ctx->inflight_threads;
  1276. uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)),
  1277. WEBGPU_COMMAND_SUBMIT_BATCH_SIZE);
  1278. if (commands.size() >= batch_size) {
  1279. futures.push_back(ggml_backend_webgpu_submit(ctx, commands));
  1280. // Process events and check for completed submissions
  1281. ctx->instance.ProcessEvents();
  1282. ggml_backend_webgpu_wait(ctx, futures, false);
  1283. commands.clear();
  1284. }
  1285. }
  1286. if (!commands.empty()) {
  1287. webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
  1288. futures.push_back(new_futures);
  1289. }
  1290. ggml_backend_webgpu_wait(ctx, futures);
  1291. ctx->inflight_threads--;
  1292. WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
  1293. return GGML_STATUS_SUCCESS;
  1294. }
  1295. static ggml_backend_i ggml_backend_webgpu_i = {
  1296. /* .get_name = */ ggml_backend_webgpu_name,
  1297. /* .free = */ ggml_backend_webgpu_free,
  1298. /* .set_tensor_async = */ NULL,
  1299. /* .get_tensor_async = */ NULL,
  1300. /* .cpy_tensor_async = */ NULL,
  1301. /* .synchronize = */ NULL,
  1302. /* .graph_plan_create = */ NULL,
  1303. /* .graph_plan_free = */ NULL,
  1304. /* .graph_plan_update = */ NULL,
  1305. /* .graph_plan_compute = */ NULL,
  1306. /* .graph_compute = */ ggml_backend_webgpu_graph_compute,
  1307. /* .event_record = */ NULL,
  1308. /* .event_wait = */ NULL,
  1309. /* .graph_optimize = */ NULL,
  1310. };
  1311. /* End GGML Backend Interface */
  1312. /* GGML Backend Buffer Interface */
  1313. static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
  1314. ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
  1315. ctx->buffer.Destroy();
  1316. }
  1317. // Returns the "fake" base pointer.
  1318. static void * ggml_backend_webgpu_buffer_get_base(ggml_backend_buffer_t buffer) {
  1319. GGML_UNUSED(buffer);
  1320. return webgpu_ptr_base;
  1321. }
  1322. static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffer,
  1323. ggml_tensor * tensor,
  1324. uint8_t value,
  1325. size_t offset,
  1326. size_t size) {
  1327. if (size == 0) {
  1328. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do.");
  1329. return;
  1330. }
  1331. WEBGPU_CPU_PROFILE_TOTAL_START(memset_tensor);
  1332. ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
  1333. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value
  1334. << ", " << offset << ", " << size << ")");
  1335. size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
  1336. // This is a trick to set all bytes of a u32 to the same 1 byte value.
  1337. uint32_t val32 = (uint32_t) value * 0x01010101;
  1338. ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size);
  1339. WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx);
  1340. }
  1341. static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer,
  1342. ggml_tensor * tensor,
  1343. const void * data,
  1344. size_t offset,
  1345. size_t size) {
  1346. WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor);
  1347. ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
  1348. webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
  1349. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
  1350. << ", " << offset << ", " << size << ")");
  1351. size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
  1352. webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4);
  1353. if (size % 4 != 0) {
  1354. // If size is not a multiple of 4, we need to memset the remaining bytes
  1355. size_t remaining_size = size % 4;
  1356. // pack the remaining bytes into a uint32_t
  1357. uint32_t val32 = 0;
  1358. for (size_t i = 0; i < remaining_size; i++) {
  1359. ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i];
  1360. }
  1361. // memset the remaining bytes
  1362. ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size),
  1363. remaining_size);
  1364. } else {
  1365. // wait for WriteBuffer to complete
  1366. webgpu_ctx->instance.WaitAny(
  1367. webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous,
  1368. [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) {
  1369. if (status != wgpu::QueueWorkDoneStatus::Success) {
  1370. GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n",
  1371. std::string(message).c_str());
  1372. }
  1373. }),
  1374. UINT64_MAX);
  1375. }
  1376. WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx);
  1377. }
  1378. static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer,
  1379. const ggml_tensor * tensor,
  1380. void * data,
  1381. size_t offset,
  1382. size_t size) {
  1383. WEBGPU_CPU_PROFILE_TOTAL_START(get_tensor);
  1384. ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
  1385. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data
  1386. << ", " << offset << ", " << size << ")");
  1387. webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx;
  1388. wgpu::Device device = webgpu_ctx->device;
  1389. size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset;
  1390. size_t final_size = size;
  1391. if (size % 4 != 0) {
  1392. // If size is not a multiple of 4, we need to round it up to the next multiple of 4
  1393. final_size = size + (4 - (size % 4));
  1394. }
  1395. std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex);
  1396. if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) {
  1397. // Create a new staging buffer if it doesn't exist or is too small
  1398. if (webgpu_ctx->get_tensor_staging_buf) {
  1399. webgpu_ctx->get_tensor_staging_buf.Destroy();
  1400. }
  1401. ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size,
  1402. wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf");
  1403. }
  1404. // Copy the data from the buffer to the staging buffer
  1405. wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
  1406. encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size);
  1407. wgpu::CommandBuffer commands = encoder.Finish();
  1408. // Submit the command buffer to the queue
  1409. webgpu_ctx->queue.Submit(1, &commands);
  1410. // Map the staging buffer to read the data
  1411. ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size);
  1412. // Must specify size here since the staging buffer might be larger than the tensor size
  1413. const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size);
  1414. // Copy the data from the mapped range to the output buffer
  1415. std::memcpy(data, mapped_range, size);
  1416. webgpu_ctx->get_tensor_staging_buf.Unmap();
  1417. WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx);
  1418. }
  1419. static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
  1420. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")");
  1421. WEBGPU_CPU_PROFILE_TOTAL_START(clear);
  1422. ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context;
  1423. ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size);
  1424. WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx);
  1425. }
  1426. static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = {
  1427. /* .free_buffer = */ ggml_backend_webgpu_buffer_free_buffer,
  1428. /* .get_base = */ ggml_backend_webgpu_buffer_get_base,
  1429. /* .init_tensor = */ NULL, // TODO: optional, needed?
  1430. /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor,
  1431. /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor,
  1432. /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor,
  1433. /* .cpy_tensor = */ NULL, // TODO: optional, implement this
  1434. /* .clear = */ ggml_backend_webgpu_buffer_clear,
  1435. /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor
  1436. };
  1437. /* End GGML Backend Buffer Interface */
  1438. /* GGML Backend Buffer Type Interface */
  1439. static const char * ggml_backend_webgpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
  1440. ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
  1441. return ctx->device_name.c_str();
  1442. }
  1443. static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
  1444. size_t size) {
  1445. static std::atomic<int> buffer_count;
  1446. int buffer_id = buffer_count++;
  1447. std::string buf_name = "tensor_buf" + std::to_string(buffer_id);
  1448. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes");
  1449. ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
  1450. wgpu::Buffer buf;
  1451. ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT),
  1452. wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
  1453. buf_name.c_str());
  1454. ggml_backend_webgpu_buffer_context * buf_ctx =
  1455. new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name);
  1456. return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size);
  1457. }
  1458. static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
  1459. ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
  1460. return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment;
  1461. }
  1462. // maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding.
  1463. static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
  1464. ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context);
  1465. return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize;
  1466. }
  1467. /* End GGML Backend Buffer Type Interface */
  1468. /* GGML Backend Device Interface */
  1469. static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) {
  1470. ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
  1471. return ctx->device_name.c_str();
  1472. }
  1473. static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) {
  1474. ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
  1475. return ctx->device_desc.c_str();
  1476. }
  1477. static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
  1478. ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
  1479. // TODO: what do we actually want to return here? maxBufferSize might not be the full available memory.
  1480. *free = ctx->webgpu_ctx->limits.maxBufferSize;
  1481. *total = ctx->webgpu_ctx->limits.maxBufferSize;
  1482. }
  1483. static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) {
  1484. GGML_UNUSED(dev);
  1485. return GGML_BACKEND_DEVICE_TYPE_GPU;
  1486. }
  1487. static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
  1488. props->name = ggml_backend_webgpu_device_get_name(dev);
  1489. props->description = ggml_backend_webgpu_device_get_description(dev);
  1490. props->type = ggml_backend_webgpu_device_get_type(dev);
  1491. ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
  1492. props->caps = {
  1493. /* .async = */ false,
  1494. /* .host_buffer = */ false,
  1495. /* .buffer_from_host_ptr = */ false,
  1496. /* .events = */ false,
  1497. };
  1498. }
  1499. static ggml_guid_t ggml_backend_webgpu_guid(void) {
  1500. static const char * guid_str = "__ggml_webgpu :)";
  1501. return reinterpret_cast<ggml_guid_t>((void *) guid_str);
  1502. }
  1503. // Workgroup size is a common constant
  1504. static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) {
  1505. std::vector<wgpu::ConstantEntry> constants(1);
  1506. constants[0].key = "wg_size";
  1507. constants[0].value = wg_size;
  1508. return constants;
  1509. }
  1510. static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) {
  1511. // we use the maximum workgroup size for the memset pipeline
  1512. size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension;
  1513. // Size the bytes_per_thread so that the largest buffer size can be handled
  1514. webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads);
  1515. std::vector<wgpu::ConstantEntry> constants(2);
  1516. constants[0].key = "wg_size";
  1517. constants[0].value = WEBGPU_MAX_WG_SIZE;
  1518. constants[1].key = "bytes_per_thread";
  1519. constants[1].value = webgpu_ctx->memset_bytes_per_thread;
  1520. webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants);
  1521. }
  1522. static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
  1523. // Q4/Q5/Q8 classic quantizations
  1524. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] =
  1525. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32");
  1526. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] =
  1527. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32");
  1528. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] =
  1529. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32");
  1530. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] =
  1531. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32");
  1532. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] =
  1533. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32");
  1534. // K-quantizations
  1535. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] =
  1536. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32");
  1537. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] =
  1538. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32");
  1539. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] =
  1540. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32");
  1541. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] =
  1542. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32");
  1543. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] =
  1544. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32");
  1545. // IQ quantizations (2-, 3-, 4-bit variants)
  1546. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] =
  1547. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32");
  1548. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] =
  1549. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32");
  1550. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] =
  1551. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32");
  1552. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] =
  1553. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32");
  1554. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] =
  1555. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32");
  1556. // 1-bit and 4-bit IQ variants
  1557. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] =
  1558. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32");
  1559. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] =
  1560. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32");
  1561. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] =
  1562. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32");
  1563. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] =
  1564. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32");
  1565. std::string proc_mul_mat_f32_f32;
  1566. std::string proc_mul_mat_f32_f32_vec;
  1567. std::string proc_mul_mat_f16_f32;
  1568. std::string proc_mul_mat_f16_f32_vec;
  1569. std::string proc_mul_mat_f16_f16;
  1570. std::string proc_mul_mat_f16_f16_vec;
  1571. std::string proc_mul_mat_q4_0_f32;
  1572. std::string proc_mul_mat_q4_0_f32_vec;
  1573. std::vector<wgpu::ConstantEntry> mul_mat_constants;
  1574. #ifndef __EMSCRIPTEN__
  1575. if (webgpu_ctx->supports_subgroup_matrix) {
  1576. std::map<std::string, std::string> sg_matrix_repls;
  1577. sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
  1578. sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K);
  1579. sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
  1580. sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
  1581. sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
  1582. sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
  1583. sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
  1584. sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
  1585. sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
  1586. proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
  1587. proc_mul_mat_f32_f32_vec =
  1588. ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls);
  1589. proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls);
  1590. proc_mul_mat_f16_f32_vec =
  1591. ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls);
  1592. proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls);
  1593. proc_mul_mat_f16_f16_vec =
  1594. ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls);
  1595. proc_mul_mat_q4_0_f32 =
  1596. ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls);
  1597. proc_mul_mat_q4_0_f32_vec =
  1598. ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls);
  1599. } else {
  1600. #endif
  1601. mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K });
  1602. mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M });
  1603. mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N });
  1604. std::map<std::string, std::string> reg_repls;
  1605. reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M);
  1606. reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N);
  1607. proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls);
  1608. proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls);
  1609. proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls);
  1610. proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls);
  1611. proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls);
  1612. proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls);
  1613. proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls);
  1614. proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls);
  1615. #ifndef __EMSCRIPTEN__
  1616. }
  1617. #endif
  1618. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
  1619. webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants);
  1620. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
  1621. webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants);
  1622. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
  1623. webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants);
  1624. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
  1625. webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants);
  1626. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
  1627. webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants);
  1628. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
  1629. webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants);
  1630. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
  1631. webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants);
  1632. webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
  1633. webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants);
  1634. std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3);
  1635. mul_mat_vec_constants[0].key = "WORKGROUP_SIZE";
  1636. mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE;
  1637. mul_mat_vec_constants[1].key = "TILE_K";
  1638. mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K;
  1639. mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG";
  1640. mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG;
  1641. webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
  1642. webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants);
  1643. webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
  1644. webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants);
  1645. webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
  1646. webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants);
  1647. webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
  1648. webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants);
  1649. webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline(
  1650. webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants);
  1651. webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
  1652. webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants);
  1653. webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline(
  1654. webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants);
  1655. }
  1656. static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) {
  1657. webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline(
  1658. webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
  1659. webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline(
  1660. webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE));
  1661. }
  1662. static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) {
  1663. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  1664. webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] =
  1665. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants);
  1666. webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] =
  1667. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants);
  1668. webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] =
  1669. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants);
  1670. webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] =
  1671. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants);
  1672. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] =
  1673. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants);
  1674. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] =
  1675. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants);
  1676. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] =
  1677. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants);
  1678. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] =
  1679. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants);
  1680. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] =
  1681. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants);
  1682. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] =
  1683. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants);
  1684. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] =
  1685. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants);
  1686. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] =
  1687. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants);
  1688. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] =
  1689. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants);
  1690. webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] =
  1691. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants);
  1692. webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] =
  1693. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants);
  1694. webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] =
  1695. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants);
  1696. webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] =
  1697. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants);
  1698. webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] =
  1699. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants);
  1700. webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] =
  1701. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants);
  1702. webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] =
  1703. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants);
  1704. webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] =
  1705. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants);
  1706. webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] =
  1707. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants);
  1708. webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] =
  1709. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants);
  1710. }
  1711. static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
  1712. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  1713. webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] =
  1714. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants);
  1715. webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] =
  1716. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants);
  1717. webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] =
  1718. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants);
  1719. webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] =
  1720. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
  1721. }
  1722. static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
  1723. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  1724. webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
  1725. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants);
  1726. webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
  1727. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants);
  1728. webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
  1729. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
  1730. webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
  1731. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
  1732. }
  1733. static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
  1734. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  1735. webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
  1736. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants);
  1737. webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
  1738. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants);
  1739. webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
  1740. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
  1741. webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
  1742. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
  1743. }
  1744. static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
  1745. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  1746. webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
  1747. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants);
  1748. webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
  1749. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants);
  1750. webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
  1751. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
  1752. webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
  1753. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
  1754. }
  1755. static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
  1756. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  1757. webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
  1758. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants);
  1759. webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
  1760. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants);
  1761. webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
  1762. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
  1763. webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
  1764. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
  1765. }
  1766. static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
  1767. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
  1768. webgpu_ctx->rms_norm_pipelines[0] =
  1769. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants);
  1770. webgpu_ctx->rms_norm_pipelines[1] =
  1771. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants);
  1772. }
  1773. static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) {
  1774. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  1775. webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] =
  1776. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants);
  1777. webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] =
  1778. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants);
  1779. webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] =
  1780. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants);
  1781. webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] =
  1782. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants);
  1783. webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] =
  1784. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants);
  1785. webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] =
  1786. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants);
  1787. webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] =
  1788. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants);
  1789. webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] =
  1790. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants);
  1791. }
  1792. static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) {
  1793. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  1794. // REGLU
  1795. webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] =
  1796. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants);
  1797. webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] =
  1798. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants);
  1799. webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] =
  1800. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants);
  1801. webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] =
  1802. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants);
  1803. // GEGLU
  1804. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] =
  1805. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants);
  1806. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] =
  1807. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants);
  1808. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] =
  1809. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants);
  1810. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] =
  1811. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants);
  1812. // SWIGLU
  1813. webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] =
  1814. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants);
  1815. webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] =
  1816. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants);
  1817. webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] =
  1818. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants);
  1819. webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] =
  1820. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants);
  1821. // SWIGLU_OAI
  1822. webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] =
  1823. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants);
  1824. webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] =
  1825. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants);
  1826. // GEGLU_ERF
  1827. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] =
  1828. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants);
  1829. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] =
  1830. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants);
  1831. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] =
  1832. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants);
  1833. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] =
  1834. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants);
  1835. // GEGLU_QUICK
  1836. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] =
  1837. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants);
  1838. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] =
  1839. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants);
  1840. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] =
  1841. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants);
  1842. webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] =
  1843. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants);
  1844. }
  1845. static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) {
  1846. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  1847. // ABS
  1848. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0] =
  1849. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f32, "abs_f32", constants);
  1850. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0] =
  1851. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f16, "abs_f16", constants);
  1852. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1] =
  1853. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f32, "abs_inplace_f32", constants);
  1854. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1] =
  1855. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f16, "abs_inplace_f16", constants);
  1856. // SGN
  1857. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0] =
  1858. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f32, "sgn_f32", constants);
  1859. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0] =
  1860. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f16, "sgn_f16", constants);
  1861. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1] =
  1862. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f32, "sgn_inplace_f32", constants);
  1863. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1] =
  1864. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f16, "sgn_inplace_f16", constants);
  1865. // NEG
  1866. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0] =
  1867. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f32, "neg_f32", constants);
  1868. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0] =
  1869. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f16, "neg_f16", constants);
  1870. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1] =
  1871. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f32, "neg_inplace_f32", constants);
  1872. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1] =
  1873. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f16, "neg_inplace_f16", constants);
  1874. // STEP
  1875. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0] =
  1876. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f32, "step_f32", constants);
  1877. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0] =
  1878. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f16, "step_f16", constants);
  1879. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1] =
  1880. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f32, "step_inplace_f32", constants);
  1881. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1] =
  1882. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f16, "step_inplace_f16", constants);
  1883. // TANH
  1884. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0] =
  1885. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f32, "tanh_f32", constants);
  1886. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0] =
  1887. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f16, "tanh_f16", constants);
  1888. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1] =
  1889. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f32, "tanh_inplace_f32", constants);
  1890. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1] =
  1891. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f16, "tanh_inplace_f16", constants);
  1892. // ELU
  1893. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0] =
  1894. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f32, "elu_f32", constants);
  1895. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0] =
  1896. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f16, "elu_f16", constants);
  1897. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1] =
  1898. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f32, "elu_inplace_f32", constants);
  1899. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1] =
  1900. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f16, "elu_inplace_f16", constants);
  1901. // RELU
  1902. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0] =
  1903. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f32, "relu_f32", constants);
  1904. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0] =
  1905. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f16, "relu_f16", constants);
  1906. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1] =
  1907. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f32, "relu_inplace_f32", constants);
  1908. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1] =
  1909. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f16, "relu_inplace_f16", constants);
  1910. // SIGMOID
  1911. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0] =
  1912. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f32, "sigmoid_f32", constants);
  1913. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0] =
  1914. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f16, "sigmoid_f16", constants);
  1915. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1] =
  1916. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32", constants);
  1917. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1] =
  1918. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16", constants);
  1919. // GELU
  1920. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0] =
  1921. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f32, "gelu_f32", constants);
  1922. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0] =
  1923. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f16, "gelu_f16", constants);
  1924. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1] =
  1925. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f32, "gelu_inplace_f32", constants);
  1926. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1] =
  1927. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f16, "gelu_inplace_f16", constants);
  1928. // GELU_QUICK
  1929. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0] =
  1930. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f32, "gelu_quick_f32", constants);
  1931. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0] =
  1932. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f16, "gelu_quick_f16", constants);
  1933. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
  1934. webgpu_ctx->device, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32", constants);
  1935. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
  1936. webgpu_ctx->device, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16", constants);
  1937. // SILU
  1938. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0] =
  1939. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f32, "silu_f32", constants);
  1940. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0] =
  1941. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f16, "silu_f16", constants);
  1942. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1] =
  1943. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f32, "silu_inplace_f32", constants);
  1944. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1] =
  1945. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f16, "silu_inplace_f16", constants);
  1946. // HARDSWISH
  1947. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0] =
  1948. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f32, "hardswish_f32", constants);
  1949. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0] =
  1950. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f16, "hardswish_f16", constants);
  1951. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1] =
  1952. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32", constants);
  1953. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1] =
  1954. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16", constants);
  1955. // HARDSIGMOID
  1956. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0] =
  1957. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants);
  1958. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0] =
  1959. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants);
  1960. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline(
  1961. webgpu_ctx->device, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32", constants);
  1962. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(
  1963. webgpu_ctx->device, wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16", constants);
  1964. // EXP
  1965. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0] =
  1966. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f32, "exp_f32", constants);
  1967. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0] =
  1968. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f16, "exp_f16", constants);
  1969. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1] =
  1970. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f32, "exp_inplace_f32", constants);
  1971. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1] =
  1972. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f16, "exp_inplace_f16", constants);
  1973. // GELU_ERF
  1974. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0] =
  1975. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f32, "gelu_erf_f32", constants);
  1976. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0] =
  1977. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f16, "gelu_erf_f16", constants);
  1978. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1] =
  1979. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32", constants);
  1980. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1] =
  1981. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16", constants);
  1982. // XIELU
  1983. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0] =
  1984. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f32, "xielu_f32", constants);
  1985. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0] =
  1986. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f16, "xielu_f16", constants);
  1987. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1] =
  1988. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants);
  1989. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] =
  1990. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants);
  1991. // CEIL
  1992. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][0] =
  1993. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f32, "ceil_f32", constants);
  1994. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][0] =
  1995. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f16, "ceil_f16", constants);
  1996. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][1] =
  1997. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f32, "ceil_inplace_f32", constants);
  1998. webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][1] =
  1999. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f16, "ceil_inplace_f16", constants);
  2000. }
  2001. static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) {
  2002. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
  2003. webgpu_ctx->scale_pipelines[0] =
  2004. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants);
  2005. webgpu_ctx->scale_pipelines[1] =
  2006. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants);
  2007. }
  2008. static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
  2009. std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
  2010. // f32 (no mask)
  2011. webgpu_ctx->soft_max_pipelines[2][0][0] =
  2012. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants);
  2013. webgpu_ctx->soft_max_pipelines[2][0][1] =
  2014. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants);
  2015. webgpu_ctx->soft_max_pipelines[2][1][0] =
  2016. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants);
  2017. webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline(
  2018. webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants);
  2019. // f32 mask (mask_type = 0)
  2020. webgpu_ctx->soft_max_pipelines[0][0][0] =
  2021. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants);
  2022. webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline(
  2023. webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants);
  2024. webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline(
  2025. webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants);
  2026. webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline(
  2027. webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants);
  2028. // f16 mask (mask_type = 1)
  2029. webgpu_ctx->soft_max_pipelines[1][0][0] =
  2030. ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants);
  2031. webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline(
  2032. webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants);
  2033. webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline(
  2034. webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants);
  2035. webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline(
  2036. webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
  2037. }
  2038. static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
  2039. GGML_UNUSED(params);
  2040. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()");
  2041. ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
  2042. webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx;
  2043. static ggml_backend_webgpu_context backend_ctx;
  2044. backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name;
  2045. backend_ctx.webgpu_ctx = webgpu_ctx;
  2046. // See GGML Backend Interface section
  2047. static ggml_backend backend = {
  2048. /* .guid = */ ggml_backend_webgpu_guid(),
  2049. /* .interface = */ ggml_backend_webgpu_i,
  2050. /* .device = */ dev,
  2051. /* .context = */ &backend_ctx,
  2052. };
  2053. return &backend;
  2054. }
  2055. static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) {
  2056. // See GGML Backend Buffer Type Interface section
  2057. static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = {
  2058. /* .iface = */ {
  2059. /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name,
  2060. /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer,
  2061. /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment,
  2062. /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size,
  2063. /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
  2064. /* .is_host = */ NULL, // defaults to false
  2065. },
  2066. /* .device = */
  2067. dev,
  2068. /* .context = */ NULL,
  2069. };
  2070. return &ggml_backend_webgpu_buffer_type;
  2071. }
  2072. static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
  2073. GGML_UNUSED(dev);
  2074. return buft->iface.get_name == ggml_backend_webgpu_buffer_type_get_name;
  2075. }
  2076. static bool ggml_webgpu_supported_qtype(ggml_type type) {
  2077. switch (type) {
  2078. case GGML_TYPE_Q4_0:
  2079. case GGML_TYPE_Q4_1:
  2080. case GGML_TYPE_Q5_0:
  2081. case GGML_TYPE_Q5_1:
  2082. case GGML_TYPE_Q8_0:
  2083. case GGML_TYPE_Q2_K:
  2084. case GGML_TYPE_Q3_K:
  2085. case GGML_TYPE_Q4_K:
  2086. case GGML_TYPE_Q5_K:
  2087. case GGML_TYPE_Q6_K:
  2088. case GGML_TYPE_IQ2_XXS:
  2089. case GGML_TYPE_IQ2_XS:
  2090. case GGML_TYPE_IQ2_S:
  2091. case GGML_TYPE_IQ3_XXS:
  2092. case GGML_TYPE_IQ3_S:
  2093. case GGML_TYPE_IQ1_S:
  2094. case GGML_TYPE_IQ1_M:
  2095. case GGML_TYPE_IQ4_NL:
  2096. case GGML_TYPE_IQ4_XS:
  2097. return true;
  2098. default:
  2099. return false;
  2100. }
  2101. }
  2102. static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
  2103. ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context);
  2104. webgpu_context webgpu_ctx = ctx->webgpu_ctx;
  2105. ggml_tensor * src0 = op->src[0];
  2106. ggml_tensor * src1 = op->src[1];
  2107. ggml_tensor * src2 = op->src[2];
  2108. // on smaller devices (or CI), tensors may be larger than the max storage buffer size
  2109. if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
  2110. (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
  2111. (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
  2112. return false;
  2113. }
  2114. bool supports_op = false;
  2115. switch (op->op) {
  2116. case GGML_OP_NONE:
  2117. case GGML_OP_VIEW:
  2118. case GGML_OP_PERMUTE:
  2119. case GGML_OP_TRANSPOSE:
  2120. case GGML_OP_RESHAPE:
  2121. supports_op = true;
  2122. break;
  2123. case GGML_OP_ADD:
  2124. case GGML_OP_SUB:
  2125. case GGML_OP_MUL:
  2126. case GGML_OP_DIV:
  2127. // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
  2128. // see https://github.com/ggml-org/llama.cpp/pull/16857
  2129. supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
  2130. (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
  2131. break;
  2132. case GGML_OP_CPY:
  2133. case GGML_OP_CONT:
  2134. supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
  2135. (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
  2136. break;
  2137. case GGML_OP_SET_ROWS:
  2138. supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64);
  2139. break;
  2140. case GGML_OP_GET_ROWS:
  2141. if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 ||
  2142. ggml_webgpu_supported_qtype(src0->type)) {
  2143. supports_op = (op->type == GGML_TYPE_F32);
  2144. }
  2145. break;
  2146. case GGML_OP_MUL_MAT:
  2147. {
  2148. switch (src1->type) {
  2149. case GGML_TYPE_F16:
  2150. supports_op |= (src0->type == GGML_TYPE_F16);
  2151. break;
  2152. case GGML_TYPE_F32:
  2153. switch (src0->type) {
  2154. case GGML_TYPE_F32:
  2155. case GGML_TYPE_F16:
  2156. case GGML_TYPE_Q4_0:
  2157. case GGML_TYPE_Q4_1:
  2158. case GGML_TYPE_Q5_0:
  2159. case GGML_TYPE_Q5_1:
  2160. case GGML_TYPE_Q8_0:
  2161. case GGML_TYPE_Q2_K:
  2162. case GGML_TYPE_Q3_K:
  2163. case GGML_TYPE_Q4_K:
  2164. case GGML_TYPE_Q5_K:
  2165. case GGML_TYPE_Q6_K:
  2166. case GGML_TYPE_IQ2_XXS:
  2167. case GGML_TYPE_IQ2_XS:
  2168. case GGML_TYPE_IQ2_S:
  2169. case GGML_TYPE_IQ3_XXS:
  2170. case GGML_TYPE_IQ3_S:
  2171. case GGML_TYPE_IQ1_S:
  2172. case GGML_TYPE_IQ1_M:
  2173. case GGML_TYPE_IQ4_NL:
  2174. case GGML_TYPE_IQ4_XS:
  2175. supports_op = true;
  2176. break;
  2177. default:
  2178. break;
  2179. }
  2180. default:
  2181. break;
  2182. }
  2183. break;
  2184. }
  2185. case GGML_OP_RMS_NORM:
  2186. supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
  2187. break;
  2188. case GGML_OP_ROPE:
  2189. supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
  2190. break;
  2191. case GGML_OP_GLU:
  2192. switch (ggml_get_glu_op(op)) {
  2193. case GGML_GLU_OP_REGLU:
  2194. case GGML_GLU_OP_GEGLU:
  2195. case GGML_GLU_OP_SWIGLU:
  2196. case GGML_GLU_OP_GEGLU_ERF:
  2197. case GGML_GLU_OP_GEGLU_QUICK:
  2198. supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
  2199. break;
  2200. case GGML_GLU_OP_SWIGLU_OAI:
  2201. supports_op = op->type == GGML_TYPE_F32;
  2202. break;
  2203. default:
  2204. break;
  2205. }
  2206. break;
  2207. case GGML_OP_SCALE:
  2208. supports_op = op->type == GGML_TYPE_F32;
  2209. break;
  2210. case GGML_OP_SOFT_MAX:
  2211. supports_op = op->type == GGML_TYPE_F32;
  2212. break;
  2213. case GGML_OP_UNARY:
  2214. {
  2215. const ggml_unary_op UNARY_OP = ggml_get_unary_op(op);
  2216. switch (UNARY_OP) {
  2217. case GGML_UNARY_OP_ABS:
  2218. case GGML_UNARY_OP_SGN:
  2219. case GGML_UNARY_OP_NEG:
  2220. case GGML_UNARY_OP_STEP:
  2221. case GGML_UNARY_OP_TANH:
  2222. case GGML_UNARY_OP_ELU:
  2223. case GGML_UNARY_OP_RELU:
  2224. case GGML_UNARY_OP_SIGMOID:
  2225. case GGML_UNARY_OP_GELU:
  2226. case GGML_UNARY_OP_GELU_QUICK:
  2227. case GGML_UNARY_OP_SILU:
  2228. case GGML_UNARY_OP_HARDSWISH:
  2229. case GGML_UNARY_OP_HARDSIGMOID:
  2230. case GGML_UNARY_OP_EXP:
  2231. case GGML_UNARY_OP_GELU_ERF:
  2232. case GGML_UNARY_OP_XIELU:
  2233. case GGML_UNARY_OP_CEIL:
  2234. supports_op = supports_op =
  2235. (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
  2236. break;
  2237. default:
  2238. break;
  2239. }
  2240. }
  2241. break;
  2242. default:
  2243. break;
  2244. }
  2245. if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize ||
  2246. (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
  2247. (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) ||
  2248. (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) {
  2249. supports_op = false;
  2250. WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: ");
  2251. }
  2252. if (!supports_op) {
  2253. WEBGPU_LOG_DEBUG("ggml_webgpu op not supported: "
  2254. << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
  2255. << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
  2256. << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
  2257. } else {
  2258. WEBGPU_LOG_DEBUG("ggml_webgpu op supported: "
  2259. << ggml_op_name(op->op) << " with types dst: " << ggml_type_name(op->type)
  2260. << ", src0: " << (op->src[0] ? ggml_type_name(op->src[0]->type) : "null")
  2261. << ", src1: " << (op->src[1] ? ggml_type_name(op->src[1]->type) : "null"));
  2262. }
  2263. return supports_op;
  2264. }
  2265. static struct ggml_backend_device_i ggml_backend_webgpu_device_i = {
  2266. /* .get_name = */ ggml_backend_webgpu_device_get_name,
  2267. /* .get_description = */ ggml_backend_webgpu_device_get_description,
  2268. /* .get_memory = */ ggml_backend_webgpu_device_get_memory,
  2269. /* .get_type = */ ggml_backend_webgpu_device_get_type,
  2270. /* .get_props = */ ggml_backend_webgpu_device_get_props,
  2271. /* .init_backend = */ ggml_backend_webgpu_device_init,
  2272. /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type,
  2273. /* .get_host_buffer_type = */ NULL,
  2274. /* .buffer_from_host_ptr = */ NULL,
  2275. /* .supports_op = */ ggml_backend_webgpu_device_supports_op,
  2276. /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft,
  2277. /* .offload_op = */ NULL,
  2278. /* .event_new = */ NULL,
  2279. /* .event_free = */ NULL,
  2280. /* .event_synchronize = */ NULL,
  2281. };
  2282. /* End GGML Backend Device Interface */
  2283. /* GGML Backend Registration Interface */
  2284. static const char * ggml_backend_webgpu_reg_get_name(ggml_backend_reg_t reg) {
  2285. ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
  2286. return ctx->name;
  2287. }
  2288. static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
  2289. ggml_backend_webgpu_reg_context * ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
  2290. return ctx->device_count;
  2291. }
  2292. // TODO: Does this need to be thread safe? Is it only called once?
  2293. // Only one device is supported for now
  2294. static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
  2295. GGML_ASSERT(index == 0);
  2296. WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");
  2297. WEBGPU_CPU_PROFILE_TOTAL_START(reg_get_device);
  2298. ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context);
  2299. webgpu_context ctx = reg_ctx->webgpu_ctx;
  2300. wgpu::RequestAdapterOptions options = {};
  2301. #ifndef __EMSCRIPTEN__
  2302. // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215
  2303. const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" };
  2304. wgpu::DawnTogglesDescriptor adapterTogglesDesc;
  2305. adapterTogglesDesc.enabledToggles = adapterEnabledToggles;
  2306. adapterTogglesDesc.enabledToggleCount = 2;
  2307. options.nextInChain = &adapterTogglesDesc;
  2308. #endif
  2309. ctx->instance.WaitAny(ctx->instance.RequestAdapter(
  2310. &options, wgpu::CallbackMode::AllowSpontaneous,
  2311. [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
  2312. if (status != wgpu::RequestAdapterStatus::Success) {
  2313. GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
  2314. return;
  2315. }
  2316. ctx->adapter = std::move(adapter);
  2317. }),
  2318. UINT64_MAX);
  2319. GGML_ASSERT(ctx->adapter != nullptr);
  2320. ctx->adapter.GetLimits(&ctx->limits);
  2321. wgpu::AdapterInfo info{};
  2322. #ifndef __EMSCRIPTEN__
  2323. wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{};
  2324. if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
  2325. info.nextInChain = &subgroup_matrix_configs;
  2326. }
  2327. #endif
  2328. ctx->adapter.GetInfo(&info);
  2329. wgpu::SupportedFeatures features;
  2330. ctx->adapter.GetFeatures(&features);
  2331. // we require f16 support
  2332. GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16));
  2333. #ifndef __EMSCRIPTEN__
  2334. // Only support square f16 matrices of size 8 or 16 for now
  2335. bool valid_subgroup_matrix_config = false;
  2336. if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) {
  2337. for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) {
  2338. const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i];
  2339. if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
  2340. config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
  2341. config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
  2342. ctx->subgroup_matrix_config = config;
  2343. valid_subgroup_matrix_config = true;
  2344. break;
  2345. }
  2346. }
  2347. }
  2348. ctx->supports_subgroup_matrix = valid_subgroup_matrix_config;
  2349. #endif
  2350. // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
  2351. // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
  2352. ctx->subgroup_size = info.subgroupMaxSize;
  2353. // Initialize device
  2354. std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
  2355. #ifndef __EMSCRIPTEN__
  2356. required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization);
  2357. if (ctx->supports_subgroup_matrix) {
  2358. required_features.push_back(wgpu::FeatureName::Subgroups);
  2359. required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix);
  2360. }
  2361. #endif
  2362. #ifdef GGML_WEBGPU_GPU_PROFILE
  2363. required_features.push_back(wgpu::FeatureName::TimestampQuery);
  2364. #endif
  2365. wgpu::DeviceDescriptor dev_desc;
  2366. dev_desc.requiredLimits = &ctx->limits;
  2367. dev_desc.requiredFeatures = required_features.data();
  2368. dev_desc.requiredFeatureCount = required_features.size();
  2369. dev_desc.SetDeviceLostCallback(
  2370. wgpu::CallbackMode::AllowSpontaneous,
  2371. [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
  2372. GGML_UNUSED(device);
  2373. GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
  2374. std::string(message).c_str());
  2375. });
  2376. dev_desc.SetUncapturedErrorCallback(
  2377. [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
  2378. GGML_UNUSED(device);
  2379. GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason),
  2380. std::string(message).c_str());
  2381. });
  2382. #ifndef __EMSCRIPTEN__
  2383. // Enable Dawn-specific toggles to increase native performance
  2384. // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these,
  2385. // only for native performance?
  2386. const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init",
  2387. "disable_polyfills_on_integer_div_and_mod" };
  2388. const char * const deviceDisabledToggles[] = { "timestamp_quantization" };
  2389. wgpu::DawnTogglesDescriptor deviceTogglesDesc;
  2390. deviceTogglesDesc.enabledToggles = deviceEnabledToggles;
  2391. deviceTogglesDesc.enabledToggleCount = 4;
  2392. deviceTogglesDesc.disabledToggles = deviceDisabledToggles;
  2393. deviceTogglesDesc.disabledToggleCount = 1;
  2394. dev_desc.nextInChain = &deviceTogglesDesc;
  2395. #endif
  2396. ctx->instance.WaitAny(ctx->adapter.RequestDevice(
  2397. &dev_desc, wgpu::CallbackMode::AllowSpontaneous,
  2398. [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) {
  2399. if (status != wgpu::RequestDeviceStatus::Success) {
  2400. GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n",
  2401. std::string(message).c_str());
  2402. return;
  2403. }
  2404. ctx->device = std::move(device);
  2405. }),
  2406. UINT64_MAX);
  2407. GGML_ASSERT(ctx->device != nullptr);
  2408. // Initialize (compute) queue
  2409. ctx->queue = ctx->device.GetQueue();
  2410. // Create buffer pool for shader parameters
  2411. ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES,
  2412. wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform,
  2413. wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite);
  2414. #ifdef GGML_WEBGPU_GPU_PROFILE
  2415. // Initialize buffer pool for timestamp queries (profiling)
  2416. ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS,
  2417. WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES,
  2418. wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc,
  2419. wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst);
  2420. #endif
  2421. ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES,
  2422. wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage,
  2423. wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead);
  2424. ggml_webgpu_init_memset_pipeline(ctx);
  2425. ggml_webgpu_init_mul_mat_pipeline(ctx);
  2426. ggml_webgpu_init_set_rows_pipeline(ctx);
  2427. ggml_webgpu_init_get_rows_pipeline(ctx);
  2428. ggml_webgpu_init_cpy_pipeline(ctx);
  2429. ggml_webgpu_init_add_pipeline(ctx);
  2430. ggml_webgpu_init_sub_pipeline(ctx);
  2431. ggml_webgpu_init_mul_pipeline(ctx);
  2432. ggml_webgpu_init_div_pipeline(ctx);
  2433. ggml_webgpu_init_rms_norm_pipeline(ctx);
  2434. ggml_webgpu_init_rope_pipeline(ctx);
  2435. ggml_webgpu_init_glu_pipeline(ctx);
  2436. ggml_webgpu_init_scale_pipeline(ctx);
  2437. ggml_webgpu_init_soft_max_pipeline(ctx);
  2438. ggml_webgpu_init_unary_pipeline(ctx);
  2439. #ifdef GGML_WEBGPU_DEBUG
  2440. // Initialize debug buffers
  2441. ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
  2442. wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf");
  2443. ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t),
  2444. wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf");
  2445. #endif
  2446. static ggml_backend_webgpu_device_context device_ctx;
  2447. device_ctx.webgpu_ctx = ctx;
  2448. device_ctx.device_name = GGML_WEBGPU_NAME;
  2449. device_ctx.device_desc = info.description;
  2450. GGML_LOG_INFO(
  2451. "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | "
  2452. "device_desc: %s\n",
  2453. info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID,
  2454. std::string(info.device).c_str(), std::string(info.description).c_str());
  2455. // See GGML Backend Device Interface section
  2456. static ggml_backend_device device = {
  2457. /* .iface = */ ggml_backend_webgpu_device_i,
  2458. /* .reg = */ reg,
  2459. /* .context = */ &device_ctx,
  2460. };
  2461. WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx);
  2462. return &device;
  2463. }
  2464. static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = {
  2465. /* .get_name = */ ggml_backend_webgpu_reg_get_name,
  2466. /* .get_device_count = */ ggml_backend_webgpu_reg_get_device_count,
  2467. /* .get_device = */ ggml_backend_webgpu_reg_get_device,
  2468. /* .get_proc_address = */ NULL,
  2469. };
  2470. /* End GGML Backend Registration Interface */
  2471. ggml_backend_reg_t ggml_backend_webgpu_reg() {
  2472. WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()");
  2473. webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>();
  2474. static ggml_backend_webgpu_reg_context ctx;
  2475. ctx.webgpu_ctx = webgpu_ctx;
  2476. ctx.name = GGML_WEBGPU_NAME;
  2477. ctx.device_count = 1;
  2478. wgpu::InstanceDescriptor instance_descriptor{};
  2479. std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny };
  2480. instance_descriptor.requiredFeatures = instance_features.data();
  2481. instance_descriptor.requiredFeatureCount = instance_features.size();
  2482. #ifndef __EMSCRIPTEN__
  2483. const char * const instanceEnabledToggles[] = { "allow_unsafe_apis" };
  2484. wgpu::DawnTogglesDescriptor instanceTogglesDesc;
  2485. instanceTogglesDesc.enabledToggles = instanceEnabledToggles;
  2486. instanceTogglesDesc.enabledToggleCount = 1;
  2487. instance_descriptor.nextInChain = &instanceTogglesDesc;
  2488. #endif
  2489. webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor);
  2490. #ifdef __EMSCRIPTEN__
  2491. if (webgpu_ctx->instance == nullptr) {
  2492. GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n");
  2493. return nullptr;
  2494. }
  2495. #endif
  2496. GGML_ASSERT(webgpu_ctx->instance != nullptr);
  2497. static ggml_backend_reg reg = {
  2498. /* .api_version = */ GGML_BACKEND_API_VERSION,
  2499. /* .iface = */ ggml_backend_webgpu_reg_i,
  2500. /* .context = */ &ctx,
  2501. };
  2502. return &reg;
  2503. }
  2504. ggml_backend_t ggml_backend_webgpu_init(void) {
  2505. ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0);
  2506. return ggml_backend_webgpu_device_init(dev, nullptr);
  2507. }
  2508. GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg)