/*
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "aclnn_ops.h"

#include <aclnnop/aclnn_avgpool2d.h>
#include <aclnnop/aclnn_cast.h>
#include <aclnnop/aclnn_constant_pad_nd.h>
#include <aclnnop/aclnn_copy.h>
#include <aclnnop/aclnn_cos.h>
#include <aclnnop/aclnn_exp.h>
#include <aclnnop/aclnn_fill_scalar.h>
#include <aclnnop/aclnn_group_norm.h>
#include <aclnnop/aclnn_index_fill_tensor.h>
#include <aclnnop/aclnn_layer_norm.h>
#include <aclnnop/aclnn_matmul.h>
#include <aclnnop/aclnn_max_pool.h>
#include <aclnnop/aclnn_permute.h>
#include <aclnnop/aclnn_pow_tensor_tensor.h>
#include <aclnnop/aclnn_reduce_sum.h>
#include <aclnnop/aclnn_repeat.h>
#include <aclnnop/aclnn_repeat_interleave.h>
#include <aclnnop/aclnn_roll.h>
#include <aclnnop/aclnn_sin.h>
#include <aclnnop/aclnn_softmax.h>
#include <aclnnop/aclnn_tril.h>
#include <aclnnop/aclnn_triu.h>
#include <aclnnop/aclnn_upsample_nearest_2d.h>
#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
#include <float.h>

#include <cmath>
#include <cstring>
#include <exception>
#include <vector>

#include "kernels/ascendc_kernels.h"

#define GGML_COMMON_DECL_C

#include "../ggml-common.h"
/**
 * @brief Repeats elements of a tensor along each dimension according to the
 * specified repeat array.
 *
 * @param ctx The context for the CANN backend operations.
 * @param acl_src The source tensor to be repeated.
 * @param acl_dst The destination tensor after repeating.
 * @param repeat_array The array specifying the number of repetitions along each
 * dimension.
 */
static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                         aclTensor* acl_dst, int64_t* repeat_array) {
    // repeat tensor along each dim with repeat_array
    aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst,
                                          &workspaceSize, &executor));

    if (workspaceSize > 0) {
        // Memory from the allocator is "freed" as soon as this scope ends and
        // may be handed out to other requests, but it will not actually be
        // reused before this async task finishes, because all tasks in the
        // same stream execute in order.
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnRepeat(workspaceAddr, workspaceSize, executor, ctx.stream()));
    ACL_CHECK(aclDestroyIntArray(repeats));
}
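
/*
 * NOTE (illustrative summary, not part of the upstream source): nearly every
 * operator in this file follows the same two-phase aclnn calling pattern that
 * aclnn_repeat() above uses:
 *
 *     uint64_t workspaceSize = 0;
 *     aclOpExecutor* executor;
 *     void* workspaceAddr = nullptr;
 *
 *     // 1. ask the operator how much scratch memory it needs
 *     ACL_CHECK(aclnnXxxGetWorkspaceSize(..., &workspaceSize, &executor));
 *
 *     // 2. allocate the workspace from the pool (if any) and launch the
 *     //    kernel asynchronously on the backend stream
 *     if (workspaceSize > 0) {
 *         ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
 *         workspaceAddr = workspace_allocator.get();
 *     }
 *     ACL_CHECK(aclnnXxx(workspaceAddr, workspaceSize, executor, ctx.stream()));
 *
 * Here "aclnnXxx" stands for any concrete operator (Repeat, Add, Cat, ...).
 */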
void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];
    GGML_ASSERT(ggml_can_repeat(src, dst));

    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    int64_t repeatsArray[] = {dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2],
                              dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]};

    aclnn_repeat(ctx, acl_src, acl_dst, repeatsArray);
    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}
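
/*
 * Worked example (illustrative): for src->ne = {2, 3, 1, 1} and
 * dst->ne = {4, 3, 2, 1}, ggml_cann_repeat() builds the factors from the
 * highest ggml dimension down to the lowest:
 *
 *     repeatsArray = { dst->ne[3] / src->ne[3],   // 1
 *                      dst->ne[2] / src->ne[2],   // 2
 *                      dst->ne[1] / src->ne[1],   // 1
 *                      dst->ne[0] / src->ne[0] }; // 2
 *
 * so src is tiled twice along ne[0] and twice along ne[2].
 */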
/**
 * @brief Adds two tensors element-wise and stores the result in a destination
 * tensor.
 *
 * This function performs the operation:
 * \f[
 *    dst = acl\_src0 + alpha \times acl\_src1
 * \f]
 * where alpha is a scalar value and defaults to 1.0f.
 *
 * @param ctx The context for the CANN backend operations.
 * @param acl_src0 The first source tensor.
 * @param acl_src1 The second source tensor.
 * @param acl_dst The destination tensor where the result will be stored.
 */
static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
                      aclTensor* acl_src1, aclTensor* acl_dst) {
    aclScalar* alpha = nullptr;
    float alphaValue = 1.0f;
    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
                                       &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyScalar(alpha));
}

void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src0 = dst->src[0];
    ggml_tensor* src1 = dst->src[1];
    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));

    aclTensor* acl_src0;
    aclTensor* acl_src1;
    aclTensor* acl_dst;

    // Need bcast
    if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) {
        BCAST_SHAPE(src0, src1)
        acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
        acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
        acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0));
    } else {
        acl_src0 = ggml_cann_create_tensor(src0);
        acl_src1 = ggml_cann_create_tensor(src1);
        acl_dst = ggml_cann_create_tensor(dst);
    }

    aclnn_add(ctx, acl_src0, acl_src1, acl_dst);

    ACL_CHECK(aclDestroyTensor(acl_src0));
    ACL_CHECK(aclDestroyTensor(acl_src1));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}
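
/*
 * Broadcast example (illustrative): ggml only requires that src1 can be
 * repeated into src0's shape, e.g.
 *
 *     src0->ne = {4096, 32, 8, 1}   // hidden states
 *     src1->ne = {4096,  1, 1, 1}   // per-row bias
 *
 * In such a case ggml_cann_need_bcast() is expected to be true, and
 * BCAST_SHAPE/BCAST_PARAM expand both operands into broadcast-compatible views
 * before aclnn_add() launches the element-wise kernel.
 */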
void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    float negative_slope;
    memcpy(&negative_slope, dst->op_params, sizeof(float));
    aclScalar* acl_negative_slope =
        aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnLeakyReluGetWorkspaceSize(
        acl_src, acl_negative_slope, acl_dst, &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyScalar(acl_negative_slope));
    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}

/**
 * @brief Concatenates a list of tensors along a specified dimension and stores
 * the result in a destination tensor.
 *
 * @param ctx The context for the CANN backend operations.
 * @param tensorList The list of tensors to be concatenated.
 * @param acl_dst The destination tensor where the concatenated result will be
 * stored.
 * @param concat_dim The dimension along which the tensors will be concatenated.
 */
static void aclnn_concat(ggml_backend_cann_context& ctx,
                         aclTensorList* tensorList, aclTensor* acl_dst,
                         int64_t concat_dim) {
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnCatGetWorkspaceSize(tensorList, concat_dim, acl_dst,
                                       &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(aclnnCat(workspaceAddr, workspaceSize, executor, ctx.stream()));
}

void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src0 = dst->src[0];
    ggml_tensor* src1 = dst->src[1];
    aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
    aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    int64_t concat_dim = 1;
    aclTensor* tensors[] = {acl_src0, acl_src1};
    aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
    aclnn_concat(ctx, tensorList, acl_dst, concat_dim);

    ACL_CHECK(aclDestroyTensorList(tensorList));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}

/**
 * @brief Creates a tensor with values starting from `start`, incremented by
 * `step`, and ending before `stop`.
 *
 * This function performs the operation:
 * \f[
 *    \text{out}_{i+1} = \text{out}_i + \text{step}
 * \f]
 * the range is [start, stop).
 *
 * @param ctx The context for the CANN backend operations.
 * @param acl_dst The destination tensor where the values will be stored.
 * @param start The starting value of the range.
 * @param stop The ending value of the range (exclusive).
 * @param step The step size between consecutive values.
 * @param n_elements The number of elements in the destination tensor.
 */
static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
                         float start, float stop, float step,
                         int64_t n_elements) {
    int64_t steps = (int64_t)std::ceil((stop - start) / step);
    GGML_ASSERT(n_elements == steps);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    aclScalar* acl_start = aclCreateScalar(&start, aclDataType::ACL_FLOAT);
    aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT);
    aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT);

    ACL_CHECK(aclnnArangeGetWorkspaceSize(acl_start, acl_end, acl_step, acl_dst,
                                          &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnArange(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyScalar(acl_start));
    ACL_CHECK(aclDestroyScalar(acl_end));
    ACL_CHECK(aclDestroyScalar(acl_step));
}

void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    int64_t n_elements = ggml_nelements(dst);
    float start;
    float stop;
    float step;
    memcpy(&start, (float*)dst->op_params + 0, sizeof(float));
    memcpy(&stop, (float*)dst->op_params + 1, sizeof(float));
    memcpy(&step, (float*)dst->op_params + 2, sizeof(float));

    aclnn_arange(ctx, acl_dst, start, stop, step, n_elements);
    ACL_CHECK(aclDestroyTensor(acl_dst));
}
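
/*
 * Worked example (illustrative): with op_params = {start = 0.0f, stop = 10.0f,
 * step = 2.0f} and ggml_nelements(dst) == 5, aclnn_arange() asserts that
 *
 *     n_elements == ceil((stop - start) / step) == ceil(10 / 2) == 5
 *
 * and fills dst with {0, 2, 4, 6, 8}.
 */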
void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    dst->src[1] = dst->src[0];
    ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
}

void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    float min;
    float max;
    memcpy(&min, dst->op_params, sizeof(float));
    memcpy(&max, (float*)dst->op_params + 1, sizeof(float));

    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT);
    aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnClampGetWorkspaceSize(acl_src, acl_min, acl_max, acl_dst,
                                         &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(aclnnClamp(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyScalar(acl_min));
    ACL_CHECK(aclDestroyScalar(acl_max));
    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}

void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    // scale factor
    float v;
    memcpy(&v, dst->op_params, sizeof(float));

    aclScalar* scale = aclCreateScalar(&v, aclDataType::ACL_FLOAT);
    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, scale, acl_dst, &workspaceSize,
                                        &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyScalar(scale));
    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}

void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];
    enum ggml_sort_order order = (enum ggml_sort_order)dst->op_params[0];

    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
    ggml_cann_pool_alloc temp_buffer_allocator(
        ctx.pool(), ggml_nelements(dst) * sizeof(int64_t));
    void* buffer = temp_buffer_allocator.get();
    aclTensor* tmp_tensor =
        ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type),
                                dst->ne, dst->nb, GGML_MAX_DIMS);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnArgsortGetWorkspaceSize(
        acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), tmp_tensor,
        &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnArgsort(workspaceAddr, workspaceSize, executor, ctx.stream()));

    workspaceSize = 0;
    ACL_CHECK(aclnnCastGetWorkspaceSize(tmp_tensor,
                                        ggml_cann_type_mapping(dst->type),
                                        acl_dst, &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(tmp_tensor));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}
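
/*
 * Worked example (illustrative): argsort over the innermost dimension with
 * order == GGML_SORT_ORDER_ASC turns a row {3.0f, 1.0f, 2.0f} into the index
 * row {1, 2, 0}. The indices are first produced as int64 into tmp_tensor and
 * then cast to dst's type via aclnnCast.
 */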
void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    std::vector<int64_t> normData = {dst->ne[0]};
    aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size());
    ACL_CHECK(aclnnLayerNormGetWorkspaceSize(acl_src, norm, nullptr, nullptr,
                                             eps, acl_dst, nullptr, nullptr,
                                             &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnLayerNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyIntArray(norm));
    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}

void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    int n_groups = dst->op_params[0];

    float eps;
    memcpy(&eps, dst->op_params + 1, sizeof(float));

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    int64_t N = src->ne[3];
    int64_t C = src->ne[2];
    int64_t HxW = src->ne[1] * src->ne[0];

    size_t type_size = ggml_type_size(src->type);
    int64_t ne[] = {n_groups, N};
    size_t nb[] = {type_size, type_size * n_groups};
    size_t n_bytes = N * n_groups;

    ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes * 2);
    void* buffer = temp_buffer_allocator.get();
    aclTensor* acl_mean_out = ggml_cann_create_tensor(
        buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
    aclTensor* acl_rstd_out = ggml_cann_create_tensor(
        (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);

    ACL_CHECK(aclnnGroupNormGetWorkspaceSize(
        acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst,
        acl_mean_out, acl_rstd_out, &workspaceSize, &executor));

    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnGroupNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
    ACL_CHECK(aclDestroyTensor(acl_mean_out));
    ACL_CHECK(aclDestroyTensor(acl_rstd_out));
}

void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src0 = dst->src[0];
    ggml_tensor* src1 = dst->src[1];

    size_t nb1 = ((int32_t*)dst->op_params)[0];
    size_t nb2 = ((int32_t*)dst->op_params)[1];
    size_t nb3 = ((int32_t*)dst->op_params)[2];
    size_t offset = ((int32_t*)dst->op_params)[3];
    bool inplace = (bool)((int32_t*)dst->op_params)[4];

    size_t param_nb[] = {ggml_element_size(src0), nb1, nb2, nb3};

    aclTensor* acl_dst = ggml_cann_create_tensor(
        dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
    aclTensor* acl_src1 = ggml_cann_create_tensor(src1);

    aclScalar* alpha = nullptr;
    float alphaValue = 1.0f;
    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    if (!inplace) {
        size_t cpy_size = ggml_nbytes(dst);
        ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size,
                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
        aclTensor* acl_src0 = ggml_cann_create_tensor(
            src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
        ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
                                           &workspaceSize, &executor));
        if (workspaceSize > 0) {
            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
            workspaceAddr = workspace_allocator.get();
        }
        ACL_CHECK(
            aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
        ACL_CHECK(aclDestroyTensor(acl_src0));
    } else {
        ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src1, alpha,
                                                  &workspaceSize, &executor));
        if (workspaceSize > 0) {
            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
            workspaceAddr = workspace_allocator.get();
        }
        ACL_CHECK(aclnnInplaceAdd(workspaceAddr, workspaceSize, executor,
                                  ctx.stream()));
    }

    ACL_CHECK(aclDestroyTensor(acl_src1));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}

void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_src = ggml_cann_create_tensor(src);

    GGML_ASSERT(dst->ne[0] == 1);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    int64_t reduce_dims_host[] = {3};
    aclIntArray* reduce_dims = aclCreateIntArray(reduce_dims_host, 1);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnReduceSumGetWorkspaceSize(
        acl_src, reduce_dims, true, ggml_cann_type_mapping(src->type), acl_dst,
        &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnReduceSum(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}

void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
                                  ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];
    aclTensor* acl_src =
        ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
    aclTensor* acl_dst =
        ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);

    std::vector<int64_t> output_size{dst->ne[1], dst->ne[0]};
    auto output_size_array = aclCreateIntArray(output_size.data(), 2);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnUpsampleNearest2dGetWorkspaceSize(
        acl_src, output_size_array, acl_dst, &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(aclnnUpsampleNearest2d(workspaceAddr, workspaceSize, executor,
                                     ctx.stream()));

    ACL_CHECK(aclDestroyIntArray(output_size_array));
    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}

/**
 * @brief Pads a tensor with a specified value along each dimension.
 *
 * This function performs padding of the source tensor `acl_src` and stores the
 * result in the destination tensor `acl_dst`. The padding values for each
 * dimension are specified in the `paddings` array.
 *
 * @param ctx The context for the CANN backend operations.
 * @param acl_src The source tensor to be padded.
 * @param acl_dst The destination tensor where the padded result will be stored.
 * @param paddings An array specifying the padding values for each dimension.
 * The size of the array should be twice the number of dimensions of the tensor.
 * @param value The value to be used for padding. The default value is 0.0.
 */
static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                      aclTensor* acl_dst, int64_t* paddings,
                      float value = 0.0f) {
    aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2);
    aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnConstantPadNdGetWorkspaceSize(
        acl_src, acl_pad, acl_value, acl_dst, &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(aclnnConstantPadNd(workspaceAddr, workspaceSize, executor,
                                 ctx.stream()));

    ACL_CHECK(aclDestroyIntArray(acl_pad));
    ACL_CHECK(aclDestroyScalar(acl_value));
}
void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];
    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    // padding: each value in the array is the amount of padding to add, and
    // its position selects the dimension and side being padded:
    // [dim0.front, dim0.behind, dim1.front, dim1.behind,
    //  dim2.front, dim2.behind, dim3.front, dim3.behind]
    int64_t paddings[] = {
        0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
        0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
    aclnn_pad(ctx, acl_src, acl_dst, paddings);

    ACL_CHECK(aclDestroyTensor(acl_dst));
    ACL_CHECK(aclDestroyTensor(acl_src));
}
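
/*
 * Worked example (illustrative): padding src->ne = {3, 2, 1, 1} up to
 * dst->ne = {5, 4, 1, 1} produces
 *
 *     paddings = {0, 2,   // dim0: nothing in front, 2 elements behind
 *                 0, 2,   // dim1: nothing in front, 2 elements behind
 *                 0, 0,   // dim2: unchanged
 *                 0, 0};  // dim3: unchanged
 *
 * so the new elements on the "behind" side of dim0/dim1 are filled with the
 * default pad value 0.0f.
 */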
/**
 * @brief Performs 2D average pooling on the input tensor and stores the result
 * in the destination tensor.
 *
 * This function performs average pooling on the source tensor and stores the
 * result in the destination tensor. The pooling parameters (kernel size,
 * strides, padding) are specified in the `op_params` of the destination tensor.
 *
 * @param ctx The context for the CANN backend operations.
 * @param dst The destination tensor where the result will be stored. The source
 * tensor is referenced by `dst->src[0]`.
 */
static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx,
                                 ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    aclTensor* acl_src =
        ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
    aclTensor* acl_dst =
        ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);

    const int32_t* opts = (const int32_t*)dst->op_params;
    const int k0 = opts[1];
    const int k1 = opts[2];
    const int s0 = opts[3];
    const int s1 = opts[4];
    const int p0 = opts[5];
    const int p1 = opts[6];

    std::vector<int64_t> kernel_dims = {k1, k0};
    std::vector<int64_t> stride_dims = {s1, s0};
    std::vector<int64_t> padding_avg_dims = {p1, p0};  // (padH, padW)

    auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2);
    auto* strides = aclCreateIntArray(stride_dims.data(), 2);
    auto* paddings_avg = aclCreateIntArray(padding_avg_dims.data(), 2);

    bool ceil_mode = false;
    bool count_include_pad = true;
    int64_t divisor_override = 0;
    int8_t cube_math_type = 0;

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnAvgPool2dGetWorkspaceSize(
        acl_src, kernel_size, strides, paddings_avg, ceil_mode,
        count_include_pad, divisor_override, cube_math_type, acl_dst,
        &workspaceSize, &executor));

    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }
    ACL_CHECK(
        aclnnAvgPool2d(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
    ACL_CHECK(aclDestroyIntArray(kernel_size));
    ACL_CHECK(aclDestroyIntArray(strides));
    ACL_CHECK(aclDestroyIntArray(paddings_avg));
}

/**
 * @brief Performs 2D max pooling on the input tensor and stores the result in
 * the destination tensor.
 *
 * This function performs max pooling on the source tensor and stores the result
 * in the destination tensor. The pooling parameters (kernel size, strides,
 * padding) are specified in the `op_params` of the destination tensor.
 *
 * @param ctx The context for the CANN backend operations.
 * @param dst The destination tensor where the result will be stored. The source
 * tensor is referenced by `dst->src[0]`.
 */
static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx,
                                 ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];
    GGML_ASSERT(src->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    aclTensor* acl_src =
        ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
    aclTensor* acl_dst =
        ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);

    const int32_t* opts = (const int32_t*)dst->op_params;
    const int k0 = opts[1];
    const int k1 = opts[2];
    const int s0 = opts[3];
    const int s1 = opts[4];
    const int p0 = opts[5];
    const int p1 = opts[6];

    int64_t temp_ne[] = {src->ne[0] + p0 * 2, src->ne[1] + p1 * 2, src->ne[2],
                         src->ne[3]};
    size_t temp_nb[GGML_MAX_DIMS];

    temp_nb[0] = ggml_element_size(src);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        temp_nb[i] = temp_nb[i - 1] * temp_ne[i - 1];
    }

    ggml_cann_pool_alloc temp_buffer_allocator(
        ctx.pool(), ggml_nbytes(src) + p0 * 2 + p1 * 2 * src->nb[1]);
    void* buffer = temp_buffer_allocator.get();
    aclTensor* tmp_tensor = ggml_cann_create_tensor(
        buffer, ACL_FLOAT, ggml_element_size(src), temp_ne, temp_nb,
        GGML_MAX_DIMS, ACL_FORMAT_NCHW);

    // pad: see padding in ggml_cann_pad()
    int64_t paddings[] = {p0, p0, p1, p1, 0, 0, 0, 0};
    float value = -FLT_MAX;
    aclnn_pad(ctx, acl_src, tmp_tensor, paddings, value);

    // max_pool
    std::vector<int64_t> kernel_dims = {k1, k0};
    std::vector<int64_t> stride_dims = {s1, s0};
    // padding_max_dims: [dim0_start, dim0_end, dim1_start, dim1_end]
    std::vector<int64_t> padding_max_dims = {0, 0, 0, 0};
    std::vector<int64_t> dilation_size = {1, 1};

    auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2);
    auto* strides = aclCreateIntArray(stride_dims.data(), 2);
    auto* paddings_max = aclCreateIntArray(padding_max_dims.data(), 4);
    auto* dilations = aclCreateIntArray(dilation_size.data(), 2);

    bool ceil_mode = false;
    int64_t auto_pads = 0;

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnMaxPoolGetWorkspaceSize(
        tmp_tensor, kernel_size, strides, auto_pads, paddings_max, dilations,
        ceil_mode, acl_dst, &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnMaxPool(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyTensor(acl_src));
    ACL_CHECK(aclDestroyTensor(acl_dst));
    ACL_CHECK(aclDestroyTensor(tmp_tensor));
    ACL_CHECK(aclDestroyIntArray(kernel_size));
    ACL_CHECK(aclDestroyIntArray(strides));
    ACL_CHECK(aclDestroyIntArray(paddings_max));
    ACL_CHECK(aclDestroyIntArray(dilations));
}

void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    const int32_t* opts = (const int32_t*)dst->op_params;
    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
    switch (op) {
        case GGML_OP_POOL_AVG:
            ggml_cann_avg_pool2d(ctx, dst);
            break;
        case GGML_OP_POOL_MAX:
            ggml_cann_max_pool2d(ctx, dst);
            break;
        case GGML_OP_POOL_COUNT:
            GGML_ABORT("fatal error");
            break;
    }
}
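
/*
 * op_params layout shared by both pooling paths above (illustrative summary):
 *
 *     opts[0] = pooling op (GGML_OP_POOL_AVG / GGML_OP_POOL_MAX)
 *     opts[1] = k0, opts[2] = k1   // kernel size
 *     opts[3] = s0, opts[4] = s1   // strides
 *     opts[5] = p0, opts[6] = p1   // paddings
 *
 * Each pair is handed to aclnn in reversed order, i.e. {k1, k0}, {s1, s0},
 * {p1, p0}, to match the NCHW format the ACL tensors are created with.
 */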
/**
 * @brief Copies data from the source tensor to the destination tensor.
 *
 * This function copies data from the source tensor `acl_src` to the destination
 * tensor `acl_dst`.
 *
 * @param ctx The context for the CANN backend operations.
 * @param acl_src The source tensor from which data will be copied.
 * @param acl_dst The destination tensor where the data will be copied to.
 */
static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                      aclTensor* acl_dst) {
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnInplaceCopyGetWorkspaceSize(acl_dst, acl_src, &workspaceSize,
                                               &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnInplaceCopy(workspaceAddr, workspaceSize, executor, ctx.stream()));
}

void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    ggml_tensor* src = dst->src[0];

    aclTensor* acl_src = ggml_cann_create_tensor(src);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);

    ggml_cann_pool_alloc src_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
    ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
    src->extra = src_extra_allocator.get();
    dst->extra = dst_extra_allocator.get();
    ACL_CHECK(aclrtMemcpyAsync(src->extra, sizeof(ggml_tensor), src,
                               sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
                               ctx.stream()));
    ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst,
                               sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
                               ctx.stream()));

    if ((dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32) &&
        ggml_are_same_shape(src, dst)) {
        cann_copy(ctx, acl_src, acl_dst);
        ACL_CHECK(aclDestroyTensor(acl_src));
        ACL_CHECK(aclDestroyTensor(acl_dst));
        return;
    }
    // TODO: simplify
    if (src->type == GGML_TYPE_F16) {
        if (dst->type == GGML_TYPE_Q8_0) {
            aclrtlaunch_ascendc_quantize_f16_q8_0(
                24, ctx.stream(), src->data, dst->data,
                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
                ((ggml_tensor*)dst->extra)->ne);
            return;
        }
        if (dst->type == GGML_TYPE_Q4_0) {
            aclrtlaunch_ascendc_quantize_f16_to_q4_0(
                24, ctx.stream(), src->data, dst->data,
                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
                ((ggml_tensor*)dst->extra)->ne);
            return;
        }
        if (dst->type == GGML_TYPE_F16) {
            if (ggml_are_same_shape(src, dst)) {
                cann_copy(ctx, acl_src, acl_dst);
                ACL_CHECK(aclDestroyTensor(acl_src));
                ACL_CHECK(aclDestroyTensor(acl_dst));
                return;
            }
            if (ggml_is_contiguous(dst)) {
                const size_t src_type_size = ggml_type_size(src->type);
                if (src->nb[0] == src_type_size) {
                    // src0 is contiguous on first dimension, copy by rows
                    int64_t rows_num = ggml_nrows(src);
                    aclrtlaunch_ascendc_dup_by_rows_fp16(
                        rows_num, ctx.stream(), src->data, dst->data,
                        ((ggml_tensor*)src->extra)->ne,
                        ((ggml_tensor*)src->extra)->nb,
                        ((ggml_tensor*)dst->extra)->ne,
                        ((ggml_tensor*)dst->extra)->nb);
                    return;
                }
                GGML_ABORT("fatal error");
            }
            GGML_ABORT("fatal error");
        }
        if (dst->type == GGML_TYPE_F32) {
            if (ggml_are_same_shape(src, dst)) {
                cann_copy(ctx, acl_src, acl_dst);
                ACL_CHECK(aclDestroyTensor(acl_src));
                ACL_CHECK(aclDestroyTensor(acl_dst));
                return;
            }
            if (ggml_is_contiguous(dst)) {
                const size_t src_type_size = ggml_type_size(src->type);
                if (src->nb[0] == src_type_size) {
                    // src0 is contiguous on first dimension, copy by rows
                    int64_t rows_num = ggml_nrows(src);
                    aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32(
                        rows_num, ctx.stream(), src->data, dst->data,
                        ((ggml_tensor*)src->extra)->ne,
                        ((ggml_tensor*)src->extra)->nb,
                        ((ggml_tensor*)dst->extra)->ne,
                        ((ggml_tensor*)dst->extra)->nb);
                    return;
                }
                GGML_ABORT("fatal error");
            }
            GGML_ABORT("fatal error");
        }
        // TODO
        GGML_ABORT("fatal error");
    } else if (src->type == GGML_TYPE_F32) {
        // TODO: if (src0->type == dst->type && ne00 == ne0 && nb00 == type_size
        //           && nb0 == type_size)
        if (dst->type == GGML_TYPE_Q8_0) {
            aclrtlaunch_ascendc_quantize_f32_q8_0(
                24, ctx.stream(), src->data, dst->data,
                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
                ((ggml_tensor*)dst->extra)->ne);
            return;
        }
        if (dst->type == GGML_TYPE_Q4_0) {
            aclrtlaunch_ascendc_quantize_f32_to_q4_0(
                24, ctx.stream(), src->data, dst->data,
                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
                ((ggml_tensor*)dst->extra)->ne);
            return;
        }
        if (dst->type == GGML_TYPE_F32) {
            if (ggml_are_same_shape(src, dst)) {
                cann_copy(ctx, acl_src, acl_dst);
                ACL_CHECK(aclDestroyTensor(acl_src));
                ACL_CHECK(aclDestroyTensor(acl_dst));
                return;
            }
            if (ggml_is_contiguous(dst)) {
                const size_t src_type_size = ggml_type_size(src->type);
                if (src->nb[0] == src_type_size) {
                    // src0 is contiguous on first dimension, copy by rows
                    int64_t rows_num = ggml_nrows(src);
                    aclrtlaunch_ascendc_dup_by_rows_fp32(
                        rows_num, ctx.stream(), src->data, dst->data,
                        ((ggml_tensor*)src->extra)->ne,
                        ((ggml_tensor*)src->extra)->nb,
                        ((ggml_tensor*)dst->extra)->ne,
                        ((ggml_tensor*)dst->extra)->nb);
                    return;
                }
                GGML_ABORT("fatal error");
            } else {
                // TODO: dst not contiguous
                GGML_ABORT("fatal error");
            }
        }
        if (dst->type == GGML_TYPE_F16) {
            if (ggml_are_same_shape(src, dst)) {
                cann_copy(ctx, acl_src, acl_dst);
                ACL_CHECK(aclDestroyTensor(acl_src));
                ACL_CHECK(aclDestroyTensor(acl_dst));
                return;
            }
            if (ggml_is_contiguous(dst)) {
                const size_t src_type_size = ggml_type_size(src->type);
                if (src->nb[0] == src_type_size) {
                    // src0 is contiguous on first dimension, copy by rows
                    int64_t rows_num = ggml_nrows(src);
                    aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16(
                        rows_num, ctx.stream(), src->data, dst->data,
                        ((ggml_tensor*)src->extra)->ne,
                        ((ggml_tensor*)src->extra)->nb,
                        ((ggml_tensor*)dst->extra)->ne,
                        ((ggml_tensor*)dst->extra)->nb);
                    return;
                }
                GGML_ABORT("fatal error");
            }
        }
        // TODO
        GGML_ABORT("fatal error");
    } else {
        if (ggml_are_same_shape(src, dst)) {
            cann_copy(ctx, acl_src, acl_dst);
            ACL_CHECK(aclDestroyTensor(acl_src));
            ACL_CHECK(aclDestroyTensor(acl_dst));
            return;
        }
        GGML_ABORT("fatal error");
    }
}
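
// The aclnnRmsNorm prototypes below are declared manually (no aclnnop header
// for RmsNorm is included at the top of this file); they mirror the usual
// two-phase GetWorkspaceSize/launch signature used by every other aclnn
// operator in this file.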
#ifdef __cplusplus
extern "C" {
#endif

aclnnStatus aclnnRmsNormGetWorkspaceSize(const aclTensor* x,
                                         const aclTensor* gamma, double epsilon,
                                         const aclTensor* yOut,
                                         const aclTensor* rstdOout,
                                         uint64_t* workspaceSize,
                                         aclOpExecutor** executor);
aclnnStatus aclnnRmsNorm(void* workspace, uint64_t workspaceSize,
                         aclOpExecutor* executor, aclrtStream stream);

#ifdef __cplusplus
}
#endif

/**
 * @brief Creates an ACL tensor initialized with zeros using a provided buffer.
 *
 * This function initializes a tensor with zeros using the specified buffer and
 * tensor parameters.
 *
 * @param ctx The context for the CANN backend operations.
 * @param buffer The buffer to be used for the tensor data.
 * @param n_bytes The size of the buffer in bytes.
 * @param ne An array specifying the extents (sizes) of each dimension of the
 * tensor.
 * @param dims The number of dimensions of the tensor.
 * @param type The data type of the tensor.
 * @param type_size The size of each element in the tensor data type.
 * @return An ACL tensor initialized with zeros.
 */
static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
                             size_t n_bytes, int64_t* ne, int64_t dims,
                             aclDataType type, size_t type_size) {
    size_t nb[GGML_MAX_DIMS];
    nb[0] = type_size;
    for (int i = 1; i < dims; i++) {
        nb[i] = nb[i - 1] * ne[i - 1];
    }

    ACL_CHECK(aclrtMemsetAsync(buffer, n_bytes, 0, n_bytes, ctx.stream()));
    aclTensor* zero =
        ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
    return zero;
}

/**
 * @brief Creates an ACL tensor initialized with ones using a provided buffer.
 *
 * This function initializes a tensor with ones using the specified buffer and
 * tensor parameters.
 *
 * @param ctx The context for the CANN backend operations.
 * @param buffer The buffer to be used for the tensor data.
 * @param n_bytes The size of the buffer in bytes.
 * @param ne An array specifying the extents (sizes) of each dimension of the
 * tensor.
 * @param dims The number of dimensions of the tensor.
 * @param type The data type of the tensor.
 * @param type_size The size of each element in the tensor data type.
 * @param value The value to be used for initializing the tensor (default
 * is 1.0).
 * @return An ACL tensor initialized with ones.
 */
static aclTensor* aclnn_ones(ggml_backend_cann_context& ctx, void* buffer,
                             size_t n_bytes, int64_t* ne, int64_t dims,
                             aclDataType type, size_t type_size,
                             float value = 1.0f) {
    aclTensor* acl_tensor =
        aclnn_zero(ctx, buffer, n_bytes, ne, dims, type, type_size);
    float alpha_host = 1.0f;
    aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT);
    aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnInplaceAddsGetWorkspaceSize(acl_tensor, other, alpha,
                                               &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(
        aclnnInplaceAdds(workspaceAddr, workspaceSize, executor, ctx.stream()));

    return acl_tensor;
}
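
/*
 * Usage sketch (illustrative): aclnn_ones() first zero-fills the buffer via
 * aclnn_zero() and then applies an in-place "adds" with `value`, so
 *
 *     aclTensor* gamma = aclnn_ones(ctx, buf, ne[0] * sizeof(float), ne, 1,
 *                                   ACL_FLOAT, sizeof(float));
 *
 * yields a 1-D tensor of ne[0] elements all equal to 1.0f; passing a different
 * `value` (as ggml_cann_diag_mask() does below) fills it with that constant
 * instead.
 */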
  978. void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  979. ggml_tensor* src = dst->src[0];
  980. aclTensor* acl_src = ggml_cann_create_tensor(src);
  981. aclTensor* acl_dst = ggml_cann_create_tensor(dst);
  982. float eps;
  983. memcpy(&eps, dst->op_params, sizeof(float));
  984. GGML_ASSERT(eps > 0.0f);
  985. uint64_t workspaceSize = 0;
  986. aclOpExecutor* executor;
  987. void* workspaceAddr = nullptr;
  988. size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
  989. ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
  990. aclTensor* acl_gamma = aclnn_ones(
  991. ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
  992. ggml_cann_type_mapping(src->type), ggml_element_size(src));
  993. size_t zero_tensor_n_bytes =
  994. src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src);
  995. ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes);
  996. aclTensor* acl_rstd =
  997. aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
  998. src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
  999. ggml_element_size(src));
  1000. ACL_CHECK(aclnnRmsNormGetWorkspaceSize(
  1001. acl_src, acl_gamma, eps, acl_dst, acl_rstd, &workspaceSize, &executor));
  1002. if (workspaceSize > 0) {
  1003. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1004. workspaceAddr = workspace_allocator.get();
  1005. }
  1006. ACL_CHECK(
  1007. aclnnRmsNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1008. ACL_CHECK(aclDestroyTensor(acl_src));
  1009. ACL_CHECK(aclDestroyTensor(acl_dst));
  1010. ACL_CHECK(aclDestroyTensor(acl_gamma));
  1011. ACL_CHECK(aclDestroyTensor(acl_rstd));
  1012. }
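// For reference, the normalization computed above is (a sketch in the file's
// Doxygen LaTeX style; gamma is fixed to all-ones here, so any weight
// multiplication happens in a separate ggml op):
//
//     \f[ \text{dst}_i = \frac{\text{src}_i}{\sqrt{\frac{1}{n}\sum_j \text{src}_j^2 + \epsilon}} \f]
//
// acl_rstd only receives the auxiliary reciprocal-standard-deviation output of
// aclnnRmsNorm and is destroyed right after the call.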
// TODO: performance is low.
  1014. void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
  1015. float value) {
  1016. ggml_tensor* src = dst->src[0];
  1017. aclTensor* acl_src = ggml_cann_create_tensor(src);
  1018. aclTensor* acl_dst = ggml_cann_create_tensor(dst);
  1019. const int n_past = ((int32_t*)dst->op_params)[0];
  1020. size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] *
  1021. src->ne[3] * ggml_element_size(src);
  1022. ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
  1023. aclTensor* mask_tensor =
  1024. aclnn_ones(ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne,
  1025. GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
  1026. ggml_element_size(src), value);
  1027. uint64_t workspaceSize = 0;
  1028. aclOpExecutor* executor;
  1029. void* workspaceAddr = nullptr;
  1030. ACL_CHECK(aclnnInplaceTriuGetWorkspaceSize(mask_tensor, n_past + 1,
  1031. &workspaceSize, &executor));
  1032. if (workspaceSize > 0) {
  1033. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1034. workspaceAddr = workspace_allocator.get();
  1035. }
  1036. ACL_CHECK(
  1037. aclnnInplaceTriu(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1038. ACL_CHECK(aclnnTrilGetWorkspaceSize(acl_src, n_past + 1, acl_dst,
  1039. &workspaceSize, &executor));
  1040. if (workspaceSize > 0) {
  1041. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1042. workspaceAddr = workspace_allocator.get();
  1043. }
  1044. ACL_CHECK(aclnnTril(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1045. aclScalar* alpha = nullptr;
  1046. float alphaValue = 1.0f;
  1047. alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
  1048. ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, mask_tensor, alpha,
  1049. &workspaceSize, &executor));
  1050. if (workspaceSize > 0) {
  1051. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1052. workspaceAddr = workspace_allocator.get();
  1053. }
  1054. ACL_CHECK(
  1055. aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1056. ACL_CHECK(aclDestroyScalar(alpha));
  1057. ACL_CHECK(aclDestroyTensor(mask_tensor));
  1058. ACL_CHECK(aclDestroyTensor(acl_src));
  1059. ACL_CHECK(aclDestroyTensor(acl_dst));
  1060. }
  1061. /**
  1062. * @brief Casts the data type of a source tensor to a destination tensor.
  1063. *
  1064. * This function casts the data type of the source tensor `acl_src` to the
  1065. * specified data type `cast_data_type` and stores the result in the destination
  1066. * tensor `acl_dst`.
  1067. *
  1068. * @param ctx The context for the CANN backend operations.
* @param acl_src The source tensor whose data type will be cast.
  1070. * @param acl_dst The destination tensor where the casted result will be stored.
  1071. * @param cast_data_type The target data type to which the source tensor will be
* cast.
  1073. */
  1074. static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  1075. aclTensor* acl_dst, aclDataType cast_data_type) {
  1076. uint64_t workspaceSize = 0;
  1077. aclOpExecutor* executor;
  1078. void* workspaceAddr = nullptr;
  1079. ACL_CHECK(aclnnCastGetWorkspaceSize(acl_src, cast_data_type, acl_dst,
  1080. &workspaceSize, &executor));
  1081. if (workspaceSize > 0) {
  1082. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1083. workspaceAddr = workspace_allocator.get();
  1084. }
  1085. ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1086. }
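// Illustrative usage sketch (not from the original source): casting an F16
// source into a freshly allocated F32 buffer, mirroring the pattern used by
// ggml_cann_softmax() below. Shapes and strides are assumptions for the
// example only.
//
//     size_t nb_f32[GGML_MAX_DIMS];
//     nb_f32[0] = sizeof(float);
//     for (int i = 1; i < GGML_MAX_DIMS; i++) {
//         nb_f32[i] = nb_f32[i - 1] * src->ne[i - 1];
//     }
//     ggml_cann_pool_alloc buf(ctx.pool(), ggml_nelements(src) * sizeof(float));
//     aclTensor* acl_f32 = ggml_cann_create_tensor(
//         buf.get(), ACL_FLOAT, sizeof(float), src->ne, nb_f32, GGML_MAX_DIMS);
//     aclnn_cast(ctx, acl_src, acl_f32, ACL_FLOAT);
//     ACL_CHECK(aclDestroyTensor(acl_f32));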
  1087. /**
  1088. * @brief Permutes the dimensions of a tensor according to a specified order.
  1089. *
  1090. * This function permutes the dimensions of the source tensor `acl_src`
  1091. * according to the order specified in the `new_dim` array and stores the result
  1092. * in the destination tensor `acl_dst`.
  1093. *
  1094. * @param ctx The context for the CANN backend operations.
  1095. * @param acl_src The source tensor whose dimensions will be permuted.
  1096. * @param acl_dst The destination tensor where the permuted result will be
  1097. * stored.
  1098. * @param new_dim An array specifying the new order of dimensions for the
  1099. * tensor.
  1100. * @param dims The number of dimensions in the tensor.
  1101. */
  1102. static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  1103. aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) {
  1104. aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims);
  1105. uint64_t workspaceSize = 0;
  1106. aclOpExecutor* executor;
  1107. void* workspaceAddr = nullptr;
  1108. ACL_CHECK(aclnnPermuteGetWorkspaceSize(acl_src, acl_dims, acl_dst,
  1109. &workspaceSize, &executor));
  1110. if (workspaceSize > 0) {
  1111. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1112. workspaceAddr = workspace_allocator.get();
  1113. }
  1114. ACL_CHECK(
  1115. aclnnPermute(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1116. ACL_CHECK(aclDestroyIntArray(acl_dims));
  1117. }
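// Illustrative usage sketch (not from the original source): new_dim lists, for
// each output dimension, which source dimension it is taken from. The im2col
// post-processing below uses it to swap the two innermost logical dims.
//
//     int64_t permute_dim[] = {0, 2, 1};   // keep dim 0, swap dims 1 and 2
//     aclnn_permute(ctx, acl_src, acl_dst, permute_dim, 3);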
  1118. #ifdef __cplusplus
  1119. extern "C" {
  1120. #endif
  1121. aclnnStatus aclnnIm2colGetWorkspaceSize(const aclTensor* self,
  1122. const aclIntArray* kernelSize,
  1123. const aclIntArray* dilation,
  1124. const aclIntArray* padding,
  1125. const aclIntArray* stride,
  1126. aclTensor* out, uint64_t* workspaceSize,
  1127. aclOpExecutor** executor);
  1128. aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
  1129. aclOpExecutor* executor, aclrtStream stream);
  1130. #ifdef __cplusplus
  1131. }
  1132. #endif
  1133. static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
  1134. ggml_tensor* dst,
  1135. ggml_tensor* src1,
  1136. aclTensor* tmp_cast_tensor,
  1137. aclTensor* tmp_im2col_tensor) {
  1138. // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
  1139. int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
  1140. size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
  1141. aclTensor* acl_dst =
  1142. ggml_cann_create_tensor(dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1);
  1143. int64_t permute_dim[] = {0, 2, 1};
  1144. if (src1->type != dst->type) {
  1145. aclnn_permute(ctx, tmp_cast_tensor, acl_dst, permute_dim, 3);
  1146. } else {
  1147. aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3);
  1148. }
  1149. // release
  1150. ACL_CHECK(aclDestroyTensor(acl_dst));
  1151. }
  1152. static void ggml_cann_im2col_1d_post_process(
  1153. ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1,
  1154. aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor,
  1155. const std::vector<int64_t>& im2col_op_params) {
  1156. // get params
  1157. const int64_t KH = im2col_op_params[0];
  1158. const int64_t KW = im2col_op_params[1];
  1159. const int64_t IW = im2col_op_params[2];
  1160. const int64_t IC = im2col_op_params[3];
  1161. const int64_t N = im2col_op_params[4];
  1162. const int64_t OH = im2col_op_params[5];
  1163. const int64_t OW = im2col_op_params[6];
  1164. const int64_t s0 = im2col_op_params[7];
  1165. const int64_t p0 = im2col_op_params[8];
  1166. const int64_t d0 = im2col_op_params[9];
  1167. const int64_t n_bytes_factor = im2col_op_params[10];
  1168. // Permute: [N, IC * KH * KW, OW * OH] ->
  1169. // [N, OW * OH * n_bytes_factor, IC * KH * KW]
  1170. aclTensor* tmp_permute_tensor = nullptr;
  1171. ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool());
  1172. tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
  1173. void* tmp_permute_buffer = tmp_permute_allocator.get();
  1174. int64_t tmp_permute_ne[] = {IC * KH * KW, OW * OH * n_bytes_factor, N};
  1175. size_t tmp_permute_nb[GGML_MAX_DIMS - 1];
  1176. tmp_permute_nb[0] = ggml_type_size(dst->type);
  1177. for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
  1178. tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
  1179. }
  1180. tmp_permute_tensor = ggml_cann_create_tensor(
  1181. tmp_permute_buffer, ggml_cann_type_mapping(dst->type),
  1182. ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb,
  1183. GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
  1184. int64_t permute_dim[] = {0, 2, 1};
  1185. if (src1->type != dst->type) {
  1186. aclnn_permute(ctx, tmp_cast_tensor, tmp_permute_tensor, permute_dim, 3);
  1187. } else {
  1188. aclnn_permute(ctx, tmp_im2col_tensor, tmp_permute_tensor, permute_dim,
  1189. 3);
  1190. }
  1191. // number of times the kernel moves in W dimension
  1192. const int n_step_w = (IW + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1;
  1193. size_t offset;
  1194. void *cur_dst_buffer = dst->data, *cur_permute_buffer = tmp_permute_buffer;
  1195. // memory copy with offset to restore 1D im2col from 2d
  1196. if (IC > 1) {
  1197. offset = IC * KH * KW * n_step_w * ggml_type_size(dst->type);
  1198. size_t size_cpy = KH * KW * ggml_type_size(dst->type);
  1199. for (int c = 0; c < IC; c++) {
  1200. cur_permute_buffer = (char*)tmp_permute_buffer + offset +
  1201. KH * KW * c * ggml_type_size(dst->type);
  1202. cur_dst_buffer = (char*)dst->data +
  1203. c * KH * KW * n_step_w * ggml_type_size(dst->type);
  1204. for (int i = 0; i < n_step_w; i++) {
  1205. ACL_CHECK(aclrtMemcpyAsync(
  1206. cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
  1207. ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
  1208. cur_dst_buffer =
  1209. (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
  1210. cur_permute_buffer = (char*)cur_permute_buffer +
  1211. KH * KW * IC * ggml_type_size(dst->type);
  1212. }
  1213. }
  1214. } else {
  1215. offset = KH * KW * n_step_w *
  1216. ggml_type_size(dst->type); // equal to ggml_nbytes(dst)
  1217. ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
  1218. (char*)tmp_permute_buffer + offset, offset,
  1219. ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
  1220. }
  1221. // release
  1222. ACL_CHECK(aclDestroyTensor(tmp_permute_tensor));
  1223. }
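// Worked example for the n_step_w formula above (numbers are illustrative
// only): with IW = 10, KW = 3, s0 = 1, p0 = 1 and d0 = 1 the kernel fits
// (10 + 2*1 - 1*(3-1) - 1) / 1 + 1 = 10 times along W, so each of the IC
// channels contributes n_step_w = 10 blocks of KH*KW values that the copy
// loop above gathers back into a contiguous 1D im2col layout.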
  1224. void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  1225. ggml_tensor* src0 = dst->src[0]; // kernel
  1226. ggml_tensor* src1 = dst->src[1]; // input
  1227. GGML_ASSERT(src0->type == GGML_TYPE_F16);
  1228. GGML_ASSERT(src1->type == GGML_TYPE_F32);
  1229. GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
  1230. GGML_TENSOR_BINARY_OP_LOCALS;
  1231. // aclnnIm2col only works on 2D. set s1, p1, d1 to 1 to perform 2D
  1232. // im2col and do post-processing to restore it to 1D.
  1233. const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
  1234. const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
  1235. const int32_t s1 = is_2D ? ((const int32_t*)(dst->op_params))[1] : 1;
  1236. const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
  1237. const int32_t p1 = is_2D ? ((const int32_t*)(dst->op_params))[3] : 1;
  1238. const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
  1239. const int32_t d1 = is_2D ? ((const int32_t*)(dst->op_params))[5] : 1;
  1240. const int64_t N = ne13;
  1241. const int64_t IC = ne12;
  1242. const int64_t KH = ne01;
  1243. const int64_t KW = ne00;
  1244. const int64_t IW = ne10;
  1245. const int64_t OH = is_2D ? ne2 : 1;
  1246. const int64_t OW = ne1;
  1247. GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  1248. GGML_ASSERT(nb10 == sizeof(float));
// the allocated memory is increased to 3x when is_2D == false
  1250. const int64_t n_bytes_factor = is_2D ? 1 : 3;
  1251. // im2col: [N,C,H,W] -> [N, IC * KH * KW, OW * OH * n_bytes_factor]
  1252. aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
  1253. int64_t tmp_im2col_ne[] = {OW * OH * n_bytes_factor, IC * KH * KW, N};
  1254. size_t tmp_im2col_nb[GGML_MAX_DIMS - 1];
  1255. tmp_im2col_nb[0] = ggml_type_size(src1->type);
  1256. for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
  1257. tmp_im2col_nb[i] = tmp_im2col_nb[i - 1] * tmp_im2col_ne[i - 1];
  1258. }
  1259. // Calculate im2col.
// If dst is F16, the temporary im2col buffer still holds F32 data, so
// allocate src type-size * dst element-count bytes.
  1262. ggml_cann_pool_alloc im2col_allocator(
  1263. ctx.pool(),
  1264. ggml_nelements(dst) * ggml_element_size(src1) * n_bytes_factor);
  1265. void* tmp_im2col_buffer = im2col_allocator.get();
  1266. aclTensor* tmp_im2col_tensor = ggml_cann_create_tensor(
  1267. tmp_im2col_buffer, ggml_cann_type_mapping(src1->type),
  1268. ggml_type_size(src1->type), tmp_im2col_ne, tmp_im2col_nb,
  1269. GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
  1270. std::vector<int64_t> kernel_dims = {KH, KW};
  1271. std::vector<int64_t> dilation_size = {d1, d0};
  1272. std::vector<int64_t> padding_dims = {p1, p0};
  1273. std::vector<int64_t> stride_dims = {s1, s0};
  1274. auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2);
  1275. auto* dilations = aclCreateIntArray(dilation_size.data(), 2);
  1276. auto* paddings = aclCreateIntArray(padding_dims.data(), 2);
  1277. auto* strides = aclCreateIntArray(stride_dims.data(), 2);
  1278. uint64_t workspaceSize = 0;
  1279. aclOpExecutor* executor;
  1280. void* workspaceAddr = nullptr;
  1281. ACL_CHECK(aclnnIm2colGetWorkspaceSize(acl_src1, kernel_size, dilations,
  1282. paddings, strides, tmp_im2col_tensor,
  1283. &workspaceSize, &executor));
  1284. ggml_cann_pool_alloc workspace_allocator(ctx.pool());
  1285. if (workspaceSize > 0) {
  1286. workspace_allocator.alloc(workspaceSize);
  1287. workspaceAddr = workspace_allocator.get();
  1288. }
  1289. ACL_CHECK(
  1290. aclnnIm2col(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1291. // Cast if dst is f16.
  1292. aclTensor* tmp_cast_tensor = nullptr;
  1293. ggml_cann_pool_alloc tmp_cast_allocator(ctx.pool());
  1294. void* tmp_cast_buffer = nullptr;
  1295. if (src1->type != dst->type) {
  1296. tmp_cast_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
  1297. tmp_cast_buffer = tmp_cast_allocator.get();
  1298. size_t temp_cast_nb[GGML_MAX_DIMS - 1];
  1299. temp_cast_nb[0] = ggml_type_size(dst->type);
  1300. for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
  1301. temp_cast_nb[i] = temp_cast_nb[i - 1] * tmp_im2col_ne[i - 1];
  1302. }
  1303. tmp_cast_tensor = ggml_cann_create_tensor(
  1304. tmp_cast_buffer, ggml_cann_type_mapping(dst->type),
  1305. ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb,
  1306. GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
  1307. aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor,
  1308. ggml_cann_type_mapping(dst->type));
  1309. }
  1310. // post-processing
  1311. if (is_2D) {
  1312. ggml_cann_im2col_2d_post_process(ctx, dst, src1, tmp_cast_tensor,
  1313. tmp_im2col_tensor);
  1314. } else {
  1315. std::vector<int64_t> im2col_op_params = {
  1316. KH, KW, IW, IC, N, OH, OW, s0, p0, d0, n_bytes_factor};
  1317. ggml_cann_im2col_1d_post_process(ctx, dst, src1, tmp_cast_tensor,
  1318. tmp_im2col_tensor, im2col_op_params);
  1319. }
  1320. // release
  1321. ACL_CHECK(aclDestroyTensor(acl_src1));
  1322. ACL_CHECK(aclDestroyTensor(tmp_im2col_tensor));
  1323. ACL_CHECK(aclDestroyTensor(tmp_cast_tensor));
  1324. ACL_CHECK(aclDestroyIntArray(kernel_size));
  1325. ACL_CHECK(aclDestroyIntArray(dilations));
  1326. ACL_CHECK(aclDestroyIntArray(paddings));
  1327. ACL_CHECK(aclDestroyIntArray(strides));
  1328. }
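// Shape sketch for the 2D path above (illustrative numbers): an input of
// [N=1, IC=3, IH=8, IW=8] with a 3x3 kernel, stride 1, padding 1 and dilation 1
// gives OH = OW = 8, so aclnnIm2col produces [1, IC*KH*KW = 27, OH*OW = 64] and
// the permute in the post-processing step stores dst as [1, 64, 27], optionally
// cast to F16 when dst->type is GGML_TYPE_F16.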
  1329. /**
  1330. * @brief Applies element-wise exponential function to the elements of a tensor.
  1331. *
  1332. * This function computes the exponential of each element in the source tensor
  1333. * `acl_src` and stores the result back into the same tensor.
  1334. * The operation is defined as:
  1335. * \f[
  1336. * \text {acl_src }_i=e^{acl\_src_i}
  1337. * \f]
  1338. *
  1339. * @param ctx The context for the CANN backend operations.
  1340. * @param acl_src The tensor on which the exponential function will be applied.
  1341. */
  1342. static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
  1343. uint64_t workspaceSize = 0;
  1344. aclOpExecutor* executor;
  1345. void* workspaceAddr = nullptr;
  1346. ACL_CHECK(
  1347. aclnnInplaceExpGetWorkspaceSize(acl_src, &workspaceSize, &executor));
  1348. if (workspaceSize > 0) {
  1349. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1350. workspaceAddr = workspace_allocator.get();
  1351. }
  1352. ACL_CHECK(
  1353. aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1354. }
  1355. /**
  1356. * @brief Multiplies elements of a tensor by a scalar value, optionally
  1357. * in-place.
  1358. *
  1359. * This function multiplies each element of the source tensor `acl_src` by the
  1360. * scalar `scale` and stores the result in the destination tensor `acl_dst`. If
  1361. * `inplace` is true, `acl_dst` will not be used and the operation is performed
  1362. * in-place on `acl_src`.
  1363. * The operation is defined as:
  1364. * \f[
  1365. * \text {acl_dst }_i=\text {acl_src }_i \times \text {scale}
  1366. * \f]
  1367. *
  1368. * @param ctx The context for the CANN backend operations.
  1369. * @param acl_src The source tensor whose elements will be multiplied.
  1370. * @param scale The scalar value by which each element of `acl_src` will be
  1371. * multiplied.
  1372. * @param acl_dst The destination tensor where the result will be stored if
  1373. * `inplace` is false.
  1374. * @param inplace Flag indicating whether to perform the operation in-place on
  1375. * `acl_src`.
  1376. */
  1377. static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  1378. float scale, aclTensor* acl_dst, bool inplace) {
  1379. aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
  1380. uint64_t workspaceSize = 0;
  1381. aclOpExecutor* executor;
  1382. void* workspaceAddr = nullptr;
  1383. if (inplace) {
  1384. ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale,
  1385. &workspaceSize, &executor));
  1386. if (workspaceSize > 0) {
  1387. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1388. workspaceAddr = workspace_allocator.get();
  1389. }
  1390. ACL_CHECK(aclnnInplaceMuls(workspaceAddr, workspaceSize, executor,
  1391. ctx.stream()));
  1392. } else {
  1393. ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, acl_scale, acl_dst,
  1394. &workspaceSize, &executor));
  1395. if (workspaceSize > 0) {
  1396. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1397. workspaceAddr = workspace_allocator.get();
  1398. }
  1399. ACL_CHECK(
  1400. aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1401. }
  1402. ACL_CHECK(aclDestroyScalar(acl_scale));
  1403. }
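// Illustrative usage sketch (not from the original source): the two modes of
// aclnn_muls(). When `inplace` is true the destination argument is unused and
// may be nullptr, which is how ggml_cann_timestep_embedding() below scales its
// arange tensor.
//
//     aclnn_muls(ctx, acl_src, 0.5f, acl_dst, false);  // acl_dst = 0.5 * acl_src
//     aclnn_muls(ctx, acl_src, 0.5f, nullptr, true);   // acl_src *= 0.5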
  1404. /**
  1405. * @brief Performs an in-place element-wise multiplication of two tensors.
  1406. *
  1407. * This function performs an element-wise multiplication of the tensors
  1408. * `acl_src` and `acl_other` and stores the result in `acl_src`.
  1409. * The operation is defined as:
  1410. * \f[
  1411. * \text {acl_src }_i=\text {acl_src }_i \times \text {acl_other }_i
  1412. * \f]
  1413. *
  1414. * @param ctx The context for the CANN backend operations.
  1415. * @param acl_src The source tensor where the multiplication result will be
  1416. * stored.
  1417. * @param acl_other The tensor whose elements will be multiplied with `acl_src`.
  1418. */
  1419. static void aclnn_inplace_mul(ggml_backend_cann_context& ctx,
  1420. aclTensor* acl_src, aclTensor* acl_other) {
  1421. uint64_t workspaceSize = 0;
  1422. aclOpExecutor* executor;
  1423. void* workspaceAddr = nullptr;
  1424. ACL_CHECK(aclnnInplaceMulGetWorkspaceSize(acl_src, acl_other,
  1425. &workspaceSize, &executor));
  1426. if (workspaceSize > 0) {
  1427. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1428. workspaceAddr = workspace_allocator.get();
  1429. }
  1430. ACL_CHECK(
  1431. aclnnInplaceMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1432. }
  1433. /**
  1434. * @brief Performs element-wise multiplication of two tensors and stores the
  1435. * result in a destination tensor.
  1436. *
  1437. * This function performs element-wise multiplication of the tensors `acl_src`
  1438. * and `acl_other` and stores the result in the destination tensor `acl_dst`.
  1439. * The operation is defined as:
  1440. * \f[
  1441. * \text {acl_dst }_i=\text {acl_src }_i \times \text {acl_other }_i
  1442. * \f]
  1443. *
  1444. * @param ctx The context for the CANN backend operations.
  1445. * @param acl_src The first tensor for element-wise multiplication.
  1446. * @param acl_other The second tensor for element-wise multiplication.
  1447. * @param acl_dst The destination tensor where the result will be stored.
  1448. */
  1449. static void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  1450. aclTensor* acl_other, aclTensor* acl_dst) {
  1451. uint64_t workspaceSize = 0;
  1452. aclOpExecutor* executor;
  1453. void* workspaceAddr = nullptr;
  1454. ACL_CHECK(aclnnMulGetWorkspaceSize(acl_src, acl_other, acl_dst,
  1455. &workspaceSize, &executor));
  1456. if (workspaceSize > 0) {
  1457. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1458. workspaceAddr = workspace_allocator.get();
  1459. }
  1460. ACL_CHECK(aclnnMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1461. }
  1462. /**
  1463. * @brief Applies element-wise cosine function to the elements of a tensor.
  1464. *
  1465. * This function computes the cosine of each element in the source tensor
  1466. * `acl_src` and stores the result in the destination tensor `acl_dst`. The
  1467. * operation is defined as: \f[ \text {acl_dst }_i=\cos \left(\text {acl_src
  1468. * }_i\right) \f]
  1469. *
  1470. * @param ctx The context for the CANN backend operations.
  1471. * @param acl_src The source tensor on which the cosine function will be
  1472. * applied.
  1473. * @param acl_dst The destination tensor where the cosine results will be
  1474. * stored.
  1475. */
  1476. static void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  1477. aclTensor* acl_dst) {
  1478. uint64_t workspaceSize = 0;
  1479. aclOpExecutor* executor;
  1480. void* workspaceAddr = nullptr;
  1481. ACL_CHECK(
  1482. aclnnCosGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
  1483. if (workspaceSize > 0) {
  1484. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1485. workspaceAddr = workspace_allocator.get();
  1486. }
  1487. ACL_CHECK(aclnnCos(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1488. }
  1489. /**
  1490. * @brief Applies element-wise sine function to the elements of a tensor.
  1491. *
* This function computes the sine of each element in the source tensor
* `acl_src` and stores the result in the destination tensor `acl_dst`.
  1495. * The operation is defined as:
  1496. * \f[
  1497. * \text {acl_dst }_i=\sin \left(\text {acl_src }_i\right)
  1498. * \f]
  1499. * @param ctx The context for the CANN backend operations.
  1500. * @param acl_src The source tensor on which the sine function will be applied.
  1501. * @param acl_dst The destination tensor where the sine results will be stored.
  1502. */
  1503. static void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  1504. aclTensor* acl_dst) {
  1505. uint64_t workspaceSize = 0;
  1506. aclOpExecutor* executor;
  1507. void* workspaceAddr = nullptr;
  1508. ACL_CHECK(
  1509. aclnnSinGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
  1510. if (workspaceSize > 0) {
  1511. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1512. workspaceAddr = workspace_allocator.get();
  1513. }
  1514. ACL_CHECK(aclnnSin(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1515. }
  1516. void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
  1517. ggml_tensor* dst) {
  1518. const ggml_tensor* src = dst->src[0];
  1519. GGML_ASSERT(src->type == GGML_TYPE_F32);
  1520. GGML_ASSERT(dst->type == GGML_TYPE_F32);
  1521. const int dim = dst->op_params[0];
  1522. const int max_period = dst->op_params[1];
  1523. int half = dim / 2;
  1524. aclTensor* acl_src = ggml_cann_create_tensor(src);
  1525. // arange: [0, ..., half)
  1526. float start = 0;
  1527. float stop = half;
  1528. float step = 1;
  1529. int64_t n_elements_arange = half;
  1530. int64_t tmp_arange_ne[] = {half};
  1531. size_t tmp_arange_nb[] = {sizeof(dst->type)};
  1532. ggml_cann_pool_alloc arange_allocator(ctx.pool(), half * sizeof(dst->type));
  1533. void* tmp_arange_buffer = arange_allocator.get();
  1534. aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
  1535. tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
  1536. ggml_type_size(dst->type), tmp_arange_ne, tmp_arange_nb,
  1537. GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
  1538. aclnn_arange(ctx, tmp_arange_tensor, start, stop, step, n_elements_arange);
  1539. // freq
  1540. float freq_param = -logf(max_period) / half;
  1541. bool inplace = true;
  1542. aclnn_muls(ctx, tmp_arange_tensor, freq_param, nullptr, inplace);
  1543. aclnn_exp(ctx, tmp_arange_tensor);
  1544. // permute: src [0,1,2,3]->[0,1,3,2]
  1545. int64_t tmp_permute_ne[] = {src->ne[1], src->ne[0], src->ne[2], src->ne[3]};
  1546. size_t tmp_permute_nb[GGML_MAX_DIMS];
  1547. tmp_permute_nb[0] = ggml_type_size(src->type);
  1548. for (int i = 1; i < GGML_MAX_DIMS; i++) {
  1549. tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
  1550. }
  1551. ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src));
  1552. void* tmp_permute_buffer = permute_allocator.get();
aclTensor* tmp_permute_tensor = ggml_cann_create_tensor(
  1554. tmp_permute_buffer, ggml_cann_type_mapping(src->type),
  1555. ggml_type_size(src->type), tmp_permute_ne, tmp_permute_nb,
  1556. GGML_MAX_DIMS, ACL_FORMAT_ND);
  1557. int64_t permute_dim[] = {0, 1, 3, 2};
  1558. int64_t num_dims = 4;
aclnn_permute(ctx, acl_src, tmp_permute_tensor, permute_dim, num_dims);
  1560. // timestep * freq
  1561. int64_t tmp_mul_ne[] = {src->ne[1] * half, src->ne[0], src->ne[2],
  1562. src->ne[3]};
  1563. size_t tmp_mul_nb[GGML_MAX_DIMS];
  1564. tmp_mul_nb[0] = ggml_type_size(src->type);
  1565. for (int i = 1; i < GGML_MAX_DIMS; i++) {
  1566. tmp_mul_nb[i] = tmp_mul_nb[i - 1] * tmp_mul_ne[i - 1];
  1567. }
  1568. int mul_nelements =
  1569. src->ne[1] * half * src->ne[0] * src->ne[2] * src->ne[3];
  1570. ggml_cann_pool_alloc mul_allocator(
  1571. ctx.pool(), mul_nelements * ggml_type_size(src->type));
  1572. void* tmp_mul_buffer = mul_allocator.get();
  1573. aclTensor* tmp_mul_tensor = ggml_cann_create_tensor(
  1574. tmp_mul_buffer, ggml_cann_type_mapping(src->type),
  1575. ggml_type_size(src->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
  1576. ACL_FORMAT_ND);
aclnn_mul(ctx, tmp_permute_tensor, tmp_arange_tensor, tmp_mul_tensor);
  1578. // cos
  1579. ggml_cann_pool_alloc cos_allocator(
  1580. ctx.pool(), mul_nelements * ggml_type_size(src->type));
  1581. void* tmp_cos_buffer = cos_allocator.get();
  1582. aclTensor* tmp_cos_tensor = ggml_cann_create_tensor(
  1583. tmp_cos_buffer, ggml_cann_type_mapping(dst->type),
  1584. ggml_type_size(dst->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
  1585. ACL_FORMAT_ND);
  1586. aclnn_cos(ctx, tmp_mul_tensor, tmp_cos_tensor);
  1587. // sin
  1588. ggml_cann_pool_alloc sin_allocator(
  1589. ctx.pool(), mul_nelements * ggml_type_size(src->type));
  1590. void* tmp_sin_buffer = sin_allocator.get();
  1591. aclTensor* tmp_sin_tensor = ggml_cann_create_tensor(
  1592. tmp_sin_buffer, ggml_cann_type_mapping(dst->type),
  1593. ggml_type_size(dst->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
  1594. ACL_FORMAT_ND);
  1595. aclnn_sin(ctx, tmp_mul_tensor, tmp_sin_tensor);
  1596. // concat
  1597. int64_t concat_dim = 3;
  1598. aclTensor* acl_dst = ggml_cann_create_tensor(dst);
  1599. aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor};
  1600. aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
  1601. aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
  1602. // release
// destroying both the tensorList and its elements causes a segmentation fault.
  1604. ACL_CHECK(aclDestroyTensorList(tensorList));
  1605. ACL_CHECK(aclDestroyTensor(acl_src));
  1606. ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
ACL_CHECK(aclDestroyTensor(tmp_permute_tensor));
  1608. ACL_CHECK(aclDestroyTensor(tmp_mul_tensor));
  1609. ACL_CHECK(aclDestroyTensor(acl_dst));
  1610. }
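// For reference, the embedding assembled above follows the usual sinusoidal
// timestep formulation (a sketch in the file's Doxygen LaTeX style, with
// half = dim / 2 and t the timestep value read from src):
//
//     \f[ f_k = e^{-\ln(\text{max\_period}) \, k / \text{half}}, \quad
//         k = 0, \dots, \text{half} - 1 \f]
//     \f[ \text{dst} = \operatorname{concat}\big(\cos(t f), \; \sin(t f)\big) \f]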
  1611. /**
  1612. * @brief Fills a tensor with a scalar value.
  1613. *
  1614. * This function fills the destination tensor `acl_dst` with the scalar value
  1615. * `scalar`.
  1616. *
  1617. * @param ctx The context for the CANN backend operations.
  1618. * @param scalar The scalar value used to fill the tensor.
  1619. * @param acl_dst The destination tensor to be filled with the scalar value.
  1620. */
  1621. static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
  1622. aclTensor* acl_dst) {
  1623. auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
  1624. uint64_t workspaceSize = 0;
  1625. aclOpExecutor* executor;
  1626. void* workspaceAddr = nullptr;
  1627. ACL_CHECK(aclnnInplaceFillScalarGetWorkspaceSize(
  1628. acl_dst, acl_scalar, &workspaceSize, &executor));
  1629. if (workspaceSize > 0) {
  1630. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1631. workspaceAddr = workspace_allocator.get();
  1632. }
  1633. ACL_CHECK(aclnnInplaceFillScalar(workspaceAddr, workspaceSize, executor,
  1634. ctx.stream()));
  1635. ACL_CHECK(aclDestroyScalar(acl_scalar));
  1636. }
  1637. /**
  1638. * @brief Raises each element of a tensor to the power of the corresponding
  1639. * element in another tensor.
  1640. *
  1641. * This function computes the element-wise power of the destination tensor
  1642. * `acl_dst` raised to the power of the exponent tensor `acl_exp`.
  1643. * The operation is defined as:
  1644. * \f[
  1645. * \text {acl_dst }_i=acl\_dst_i^{\text {acl_exp }_i}
  1646. * \f]
  1647. *
  1648. * @param ctx The context for the CANN backend operations.
  1649. * @param acl_dst The destination tensor, which also serves as the base tensor.
  1650. * @param acl_exp The exponent tensor, each element of which is used to raise
  1651. * the corresponding element in the destination tensor.
  1652. */
  1653. static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
  1654. aclTensor* acl_dst, aclTensor* acl_exp) {
  1655. uint64_t workspaceSize = 0;
  1656. aclOpExecutor* executor;
  1657. void* workspaceAddr = nullptr;
  1658. ACL_CHECK(aclnnInplacePowTensorTensorGetWorkspaceSize(
  1659. acl_dst, acl_exp, &workspaceSize, &executor));
  1660. if (workspaceSize > 0) {
  1661. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1662. workspaceAddr = workspace_allocator.get();
  1663. }
  1664. ACL_CHECK(aclnnInplacePowTensorTensor(workspaceAddr, workspaceSize,
  1665. executor, ctx.stream()));
  1666. }
/**
* @brief Applies the Alibi (Attention with Linear Biases) mechanism to the
* attention scores.
*
* @details This function implements the Alibi mechanism, which introduces
* head-specific linear biases into the attention scores to simulate relative
* position encoding without the need for explicit positional
* embeddings.
*
* @param ctx The backend CANN context for executing operations.
* @param acl_src The source tensor representing the query or key.
* @param acl_position The position tensor containing relative positions.
* @param acl_dst The destination tensor where the result will be stored.
* @param n_head The number of attention heads.
* @param src_ne The dimensions of the source tensor.
* @param src_nb0 The byte size of the first dimension of the source tensor.
* @param max_bias The maximum bias value used in the Alibi mechanism.
* @param dst The destination tensor object for additional metadata.
*
* The function performs the following steps:
* 1. Calculates the logarithm floor of the number of heads to determine the
*    base for bias calculation.
* 2. Initializes arrays with arithmetic sequences and fills them with bias
*    values.
* 3. Computes the bias tensor based on the calculated biases and arithmetic
*    sequences.
* 4. Reshapes the bias tensor to match the dimensions of the input tensors.
* 5. Multiplies the position tensor by the bias tensor.
* 6. Adds the result of the multiplication to the source tensor to produce
*    the final output.
*/
  1697. static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  1698. aclTensor* acl_position, aclTensor* acl_dst,
  1699. const int n_head, int64_t* src_ne, const size_t src_nb0,
  1700. float max_bias, ggml_tensor* dst) {
  1701. const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
  1702. GGML_ASSERT(src_nb0 == sizeof(float));
  1703. GGML_ASSERT(n_head == src_ne[2]);
  1704. const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
  1705. float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
  1706. float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
  1707. // init arange
  1708. ggml_cann_pool_alloc arange_allocator(ctx.pool(),
  1709. ne2_ne3 * ggml_type_size(dst->type));
  1710. void* tmp_arange_buffer = arange_allocator.get();
  1711. // arange1: [1, ..., n_heads_log2_floor+1)
  1712. float start = 1;
  1713. float stop = n_heads_log2_floor + 1;
  1714. float step = 1;
  1715. int64_t n_elements_arange = n_heads_log2_floor;
  1716. int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
  1717. size_t tmp_arange1_nb[] = {sizeof(dst->type)};
  1718. aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
  1719. tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
  1720. ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
  1721. GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
  1722. aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
  1723. aclTensor* tmp_arange2_tensor = nullptr;
  1724. if (n_heads_log2_floor < ne2_ne3) {
  1725. // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
  1726. start = 1;
  1727. stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
  1728. step = 2;
  1729. n_elements_arange = ne2_ne3 - n_heads_log2_floor;
  1730. int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
  1731. size_t tmp_arange2_nb[] = {sizeof(dst->type)};
tmp_arange2_tensor = ggml_cann_create_tensor(
  1733. (char*)tmp_arange_buffer +
  1734. n_heads_log2_floor * ggml_type_size(dst->type),
  1735. ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
  1736. tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
  1737. aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
  1738. n_elements_arange);
  1739. }
  1740. // init mk_base
  1741. ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
  1742. ne2_ne3 * ggml_type_size(dst->type));
  1743. void* tmp_mk_base_buffer = mk_base_allocator.get();
  1744. int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
  1745. size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
  1746. aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
  1747. tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
  1748. ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
  1749. GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
  1750. aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
  1751. aclTensor* tmp_mk_base2_tensor = nullptr;
  1752. if (n_heads_log2_floor < ne2_ne3) {
  1753. int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
  1754. size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
tmp_mk_base2_tensor = ggml_cann_create_tensor(
  1756. (char*)tmp_mk_base_buffer +
  1757. n_heads_log2_floor * ggml_type_size(dst->type),
  1758. ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
  1759. tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
  1760. aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
  1761. }
  1762. // init mk
  1763. int64_t tmp_mk_base_ne[] = {ne2_ne3};
  1764. size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
  1765. aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
  1766. tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
  1767. ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
  1768. GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
  1769. aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
  1770. tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
  1771. ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
  1772. GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
  1773. aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
  1774. // reshape mk
  1775. int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
  1776. size_t tmp_mk_nb[GGML_MAX_DIMS];
  1777. tmp_mk_nb[0] = ggml_type_size(dst->type);
  1778. for (int i = 1; i < GGML_MAX_DIMS; i++) {
  1779. tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
  1780. }
  1781. aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
  1782. tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
  1783. ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
  1784. ACL_FORMAT_ND);
  1785. // acl_position * mk
  1786. int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
  1787. size_t tmp_output_nb[GGML_MAX_DIMS];
  1788. tmp_output_nb[0] = ggml_type_size(dst->type);
  1789. for (int i = 1; i < GGML_MAX_DIMS; i++) {
  1790. tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
  1791. }
  1792. ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
  1793. void* tmp_output_buffer = output_allocator.get();
  1794. aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
  1795. tmp_output_buffer, ggml_cann_type_mapping(dst->type),
  1796. ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
  1797. ACL_FORMAT_ND);
  1798. aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);
  1799. // add
  1800. aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
  1801. ACL_CHECK(aclDestroyTensor(tmp_arange1_tensor));
  1802. ACL_CHECK(aclDestroyTensor(tmp_arange2_tensor));
  1803. ACL_CHECK(aclDestroyTensor(tmp_mk_base1_tensor));
  1804. ACL_CHECK(aclDestroyTensor(tmp_mk_base2_tensor));
  1805. ACL_CHECK(aclDestroyTensor(tmp_mk_base_tensor));
  1806. ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
  1807. ACL_CHECK(aclDestroyTensor(tmp_mk_tensor));
  1808. ACL_CHECK(aclDestroyTensor(tmp_output_tensor));
  1809. }
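// Worked example for the slope computation above (illustrative numbers): with
// n_head = 12 and max_bias = 8, n_heads_log2_floor = 8, m0 = 2^(-8/8) = 0.5 and
// m1 = 2^(-4/8) ~= 0.7071. Heads 0..7 receive slopes m0^1 .. m0^8 from arange1,
// while the remaining 4 heads receive m1^1, m1^3, m1^5, m1^7 from the odd-step
// arange2, which is the usual ALiBi per-head slope schedule.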
  1810. void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  1811. ggml_cann_dup(ctx, dst);
  1812. }
  1813. /**
  1814. * @brief Performs element-wise addition of two tensors in place.
  1815. *
  1816. * This function adds the source tensor `acl_src` to the destination tensor
  1817. * `acl_dst` element-wise and stores the result in the destination tensor
  1818. * `acl_dst`.
  1819. *
  1820. * @param ctx The context for the CANN backend operations.
  1821. * @param acl_src The source tensor to be added.
  1822. * @param acl_dst The destination tensor which will hold the result of the
  1823. * addition.
  1824. */
  1825. static void aclnn_inplace_add(ggml_backend_cann_context& ctx,
  1826. aclTensor* acl_src, aclTensor* acl_dst) {
  1827. aclScalar* alpha = nullptr;
  1828. float alphaValue = 1.0f;
  1829. alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
  1830. uint64_t workspaceSize = 0;
  1831. aclOpExecutor* executor;
  1832. void* workspaceAddr = nullptr;
  1833. ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src, alpha,
  1834. &workspaceSize, &executor));
  1835. if (workspaceSize > 0) {
  1836. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1837. workspaceAddr = workspace_allocator.get();
  1838. }
  1839. ACL_CHECK(
  1840. aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
  1841. ACL_CHECK(aclDestroyScalar(alpha));
  1842. }
  1843. /**
  1844. * @brief Applies the softmax function to a tensor along a specified dimension.
  1845. *
  1846. * This function computes the softmax of the source tensor `acl_src` along the
  1847. * specified dimension `dim` and stores the result in the destination tensor
  1848. * `acl_dst`.
  1849. *
  1850. * @param ctx The context for the CANN backend operations.
  1851. * @param acl_src The source tensor on which the softmax function will be
  1852. * applied.
  1853. * @param dim The dimension along which the softmax function will be computed.
  1854. * @param acl_dst The destination tensor where the softmax results will be
  1855. * stored.
  1856. */
  1857. static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
  1858. int64_t dim, aclTensor* acl_dst) {
  1859. uint64_t workspaceSize = 0;
  1860. aclOpExecutor* executor;
  1861. void* workspaceAddr = nullptr;
  1862. ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(acl_src, dim, acl_dst,
  1863. &workspaceSize, &executor));
  1864. if (workspaceSize > 0) {
  1865. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  1866. workspaceAddr = workspace_allocator.get();
  1867. }
  1868. aclrtStream stream = ctx.stream();
  1869. ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream));
  1870. }
  1871. void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  1872. ggml_tensor* src0 = dst->src[0];
  1873. ggml_tensor* src1 = dst->src[1]; // mask
  1874. aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
  1875. aclTensor* acl_dst = ggml_cann_create_tensor(dst);
  1876. float scale = 1.0f;
  1877. float max_bias = 0.0f;
  1878. memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
  1879. memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
  1880. // input mul scale
  1881. aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
  1882. size_t n_bytes = ggml_nbytes(src0);
  1883. ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
  1884. void* input_mul_scale_buffer = mul_scale_allocator.get();
  1885. aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
  1886. input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
  1887. src0->nb, GGML_MAX_DIMS);
  1888. bool inplace = false;
  1889. aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
  1890. // mask
  1891. aclTensor* acl_src1_fp32_tensor = nullptr;
  1892. aclTensor* tmp_mask_tensor = nullptr;
  1893. ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
  1894. if (src1) {
  1895. const bool use_f16 = src1->type == GGML_TYPE_F16;
  1896. if (use_f16) {
  1897. // cast to fp32
  1898. size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
  1899. size_t src1_fp32_nb[GGML_MAX_DIMS];
  1900. src1_fp32_nb[0] = sizeof(float_t);
  1901. for (int i = 1; i < GGML_MAX_DIMS; i++) {
  1902. src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
  1903. }
  1904. src1_fp32_allocator.alloc(n_bytes);
  1905. void* src1_fp32_buffer = src1_fp32_allocator.get();
  1906. acl_src1_fp32_tensor = ggml_cann_create_tensor(
  1907. src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
  1908. src1_fp32_nb, GGML_MAX_DIMS);
  1909. aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
  1910. aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
  1911. ACL_CHECK(aclDestroyTensor(acl_src1));
  1912. } else {
  1913. acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
  1914. }
// broadcast the mask across rows: the mask provides only ne11 rows for the
// ne01 rows of src0
  1916. if (src1->ne[1] != src0->ne[1]) {
  1917. // mask shape: [1,1,ne11,ne10]
  1918. int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
  1919. size_t tmp_mask_nb[GGML_MAX_DIMS];
  1920. tmp_mask_nb[0] = sizeof(float_t);
  1921. for (int i = 1; i < GGML_MAX_DIMS; i++) {
  1922. tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
  1923. }
  1924. tmp_mask_tensor = ggml_cann_create_tensor(
  1925. src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
  1926. GGML_MAX_DIMS, ACL_FORMAT_ND);
  1927. }
  1928. // alibi
  1929. const int n_head = src0->ne[2];
  1930. const size_t src_nb0 = src0->nb[0];
  1931. n_bytes = ggml_nbytes(dst);
  1932. ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
  1933. void* output_buffer = output_allocator.get();
  1934. aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
  1935. output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
  1936. dst->nb, GGML_MAX_DIMS);
  1937. if (max_bias <= 0.0f) {
  1938. // slope = 1.0
  1939. if (tmp_mask_tensor) {
  1940. aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
  1941. alibi_output_tensor);
  1942. } else {
  1943. aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
  1944. alibi_output_tensor);
  1945. }
  1946. } else {
  1947. // slope != 1.0
  1948. if (tmp_mask_tensor) {
  1949. aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
  1950. alibi_output_tensor, n_head, src0->ne, src_nb0,
  1951. max_bias, dst);
  1952. } else {
  1953. aclnn_alibi(ctx, acl_input_mul_scale_tensor,
  1954. acl_src1_fp32_tensor, alibi_output_tensor, n_head,
  1955. src0->ne, src_nb0, max_bias, dst);
  1956. }
  1957. }
  1958. // softmax
  1959. aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
  1960. ACL_CHECK(aclDestroyTensor(alibi_output_tensor));
  1961. } else {
  1962. aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
  1963. }
  1964. ACL_CHECK(aclDestroyTensor(acl_src0));
  1965. ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor));
  1966. ACL_CHECK(aclDestroyTensor(acl_dst));
  1967. ACL_CHECK(aclDestroyScalar(acl_scale));
  1968. ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor));
  1969. ACL_CHECK(aclDestroyTensor(tmp_mask_tensor));
  1970. }
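// For reference, the op_params layout consumed above follows the ggml SOFT_MAX
// convention: params[0] is the scale applied to src0 (typically 1/sqrt(d_head)
// in attention) and params[1] is max_bias, which stays at 0.0f unless ALiBi
// position bias is requested.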
  1971. void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  1972. ggml_tensor* src0 = dst->src[0];
  1973. ggml_tensor* src1 = dst->src[1];
  1974. ggml_cann_pool_alloc src0_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
  1975. ggml_cann_pool_alloc src1_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
  1976. ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
  1977. src0->extra = src0_extra_allocator.get();
  1978. src1->extra = src1_extra_allocator.get();
  1979. dst->extra = dst_extra_allocator.get();
  1980. ACL_CHECK(aclrtMemcpyAsync(src0->extra, sizeof(ggml_tensor), src0,
  1981. sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
  1982. ctx.stream()));
  1983. ACL_CHECK(aclrtMemcpyAsync(src1->extra, sizeof(ggml_tensor), src1,
  1984. sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
  1985. ctx.stream()));
  1986. ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst,
  1987. sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
  1988. ctx.stream()));
  1989. switch (src0->type) {
  1990. case GGML_TYPE_F32:
  1991. aclrtlaunch_ascendc_get_row_f32(
  1992. 24, ctx.stream(), src0->data, src1->data, dst->data,
  1993. ((ggml_tensor*)src0->extra)->ne,
  1994. ((ggml_tensor*)src0->extra)->nb,
  1995. ((ggml_tensor*)src1->extra)->ne,
  1996. ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
  1997. ((ggml_tensor*)dst->extra)->nb);
  1998. break;
  1999. case GGML_TYPE_F16:
  2000. aclrtlaunch_ascendc_get_row_f16(
  2001. 24, ctx.stream(), src0->data, src1->data, dst->data,
  2002. ((ggml_tensor*)src0->extra)->ne,
  2003. ((ggml_tensor*)src0->extra)->nb,
  2004. ((ggml_tensor*)src1->extra)->ne,
  2005. ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
  2006. ((ggml_tensor*)dst->extra)->nb);
  2007. break;
  2008. case GGML_TYPE_Q4_0:
  2009. aclrtlaunch_ascendc_get_row_q4_0(
  2010. 24, ctx.stream(), src0->data, src1->data, dst->data,
  2011. ((ggml_tensor*)src0->extra)->ne,
  2012. ((ggml_tensor*)src1->extra)->ne,
  2013. ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
  2014. ((ggml_tensor*)dst->extra)->nb);
  2015. break;
  2016. case GGML_TYPE_Q8_0:
  2017. aclrtlaunch_ascendc_get_row_q8_0(
  2018. 24, ctx.stream(), src0->data, src1->data, dst->data,
  2019. ((ggml_tensor*)src0->extra)->ne,
  2020. ((ggml_tensor*)src1->extra)->ne,
  2021. ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
  2022. ((ggml_tensor*)dst->extra)->nb);
  2023. break;
  2024. default:
  2025. GGML_ABORT("fatal error");
  2026. break;
  2027. }
  2028. }
  2029. /**
  2030. * @brief Repeats elements of a tensor along a specified dimension.
  2031. *
  2032. * This function repeats each element of the source tensor `acl_src` a specified
  2033. * number of times (`repeats`) along the specified dimension `dim` and stores
  2034. * the result in the destination tensor `acl_dst`.
  2035. *
  2036. * @param ctx The context for the CANN backend operations.
  2037. * @param acl_src The source tensor whose elements will be repeated.
  2038. * @param acl_dst The destination tensor where the repeated elements will be
  2039. * stored.
  2040. * @param dim The dimension along which the elements will be repeated.
  2041. * @param repeats The number of times each element will be repeated.
  2042. * @param output_size The size of the output tensor.
  2043. */
  2044. static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx,
  2045. aclTensor* acl_src, aclTensor* acl_dst,
  2046. int64_t dim, int64_t repeats,
  2047. int64_t output_size) {
  2048. uint64_t workspaceSize = 0;
  2049. aclOpExecutor* executor;
  2050. void* workspaceAddr = nullptr;
  2051. ACL_CHECK(aclnnRepeatInterleaveIntWithDimGetWorkspaceSize(
  2052. acl_src, repeats, dim, output_size, acl_dst, &workspaceSize,
  2053. &executor));
  2054. if (workspaceSize > 0) {
  2055. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  2056. workspaceAddr = workspace_allocator.get();
  2057. }
  2058. ACL_CHECK(aclnnRepeatInterleaveIntWithDim(workspaceAddr, workspaceSize,
  2059. executor, ctx.stream()));
  2060. }
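// Illustrative usage sketch (not from the original source): repeating every
// element along one dimension. `out_dim_size` is a placeholder for the expected
// size of `dim` in acl_dst, i.e. repeats times the input size along that dim.
//
//     int64_t dim = 3, repeats = 2;
//     aclnn_repeat_interleave(ctx, acl_src, acl_dst, dim, repeats, out_dim_size);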
  2061. /**
  2062. * @brief Performs matrix multiplication of two tensors.
  2063. *
  2064. * This function computes the matrix multiplication of the input tensor
  2065. * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
  2066. * destination tensor `acl_dst`.
  2067. * The operation is defined as:
  2068. * \f[
  2069. * \text {acl_dst}=\text {acl_input@acl_weight}
  2070. * \f]
  2071. *
  2072. * @param ctx The context for the CANN backend operations.
  2073. * @param acl_input The input tensor for the matrix multiplication.
  2074. * @param acl_weight The weight tensor for the matrix multiplication.
  2075. * @param acl_dst The destination tensor where the result of the matrix
  2076. * multiplication will be stored.
  2077. */
  2078. static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
  2079. aclTensor* acl_weight, aclTensor* acl_dst) {
// ALLOW_FP32_DOWN_PRECISION: when the input is fp32, Atlas A2 converts it
// to HFLOAT32 instead of keeping full fp32 precision.
int8_t cube_math_type = 1;
  2082. uint64_t workspaceSize = 0;
  2083. aclOpExecutor* executor;
  2084. void* workspaceAddr = nullptr;
  2085. ACL_CHECK(aclnnMatmulGetWorkspaceSize(acl_input, acl_weight, acl_dst,
  2086. cube_math_type, &workspaceSize,
  2087. &executor));
  2088. if (workspaceSize > 0) {
  2089. ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
  2090. workspaceAddr = workspace_allocator.get();
  2091. }
  2092. ACL_CHECK(
  2093. aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream()));
  2094. }
  2095. /**
  2096. * @brief Performs matrix multiplication with floating-point precision on
  2097. * tensors using the CANN backend.
  2098. *
  2099. * This function performs matrix multiplication of the input tensor and the
  2100. * weight tensor, handling broadcasting and transposing as needed, and stores
  2101. * the result in the destination tensor `dst`.
  2102. *
  2103. * @param ctx The context for the CANN backend operations.
  2104. * @param dst The destination tensor where the result of the matrix
  2105. * multiplication will be stored.
  2106. */
  2107. static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
  2108. ggml_tensor* dst) {
  2109. ggml_tensor* weight = dst->src[0]; // weight
  2110. ggml_tensor* input = dst->src[1]; // input
// When the weight's ne2 or ne3 is 1, aclnnMatmulGetWorkspaceSize broadcasts
// automatically; when ne2 or ne3 is not 1, the weight needs to be repeated.
  2113. BCAST_MUL_MAT_SHAPE(input, weight, dst);
  2114. // transpose weight: [1,2,3,4] -> [1,2,4,3]
  2115. int64_t transpose_ne[] = {bcast_weight_ne[1], bcast_weight_ne[0],
  2116. bcast_weight_ne[2], bcast_weight_ne[3],
  2117. bcast_weight_ne[4], bcast_weight_ne[5]};
  2118. size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
  2119. bcast_weight_nb[2], bcast_weight_nb[3],
  2120. bcast_weight_nb[4], bcast_weight_nb[5]};
  2121. aclTensor* acl_weight_tensor =
  2122. ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, bcast_dims);
  2123. aclTensor* acl_input_tensor =
  2124. ggml_cann_create_tensor(input, BCAST_MUL_MAT_PARAM(input));
  2125. aclTensor* acl_dst = ggml_cann_create_tensor(dst, BCAST_MUL_MAT_PARAM(dst));
  2126. aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
  2127. ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
  2128. ACL_CHECK(aclDestroyTensor(acl_input_tensor));
  2129. ACL_CHECK(aclDestroyTensor(acl_dst));
  2130. }
  2131. /**
  2132. * @brief Performs matrix multiplication with quantized weights and
  2133. * floating-point inputs using the CANN backend.
  2134. *
  2135. * This function performs matrix multiplication of the input tensor `src1` and
  2136. * the weight tensor `src0`, handling broadcasting, transposing, and
  2137. * quantization as needed, and stores the result in the destination tensor
  2138. * `dst`.
  2139. *
  2140. * @param ctx The context for the CANN backend operations.
  2141. * @param dst The destination tensor where the result of the matrix
  2142. * multiplication will be stored.
  2143. */
  2144. static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
  2145. ggml_tensor* dst,
  2146. const enum ggml_type type) {
  2147. ggml_tensor* src0 = dst->src[0]; // weight
  2148. ggml_tensor* src1 = dst->src[1]; // input
// The shape of the weight is NCHW. Matrix multiplication uses the HW dims;
// NC is regarded as batch. The weight needs to be transposed.
  2151. int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
  2152. float weight_elem_size;
if (type == GGML_TYPE_Q4_0) {
weight_elem_size = float(sizeof(uint8_t)) / 2;
} else if (type == GGML_TYPE_Q8_0) {
weight_elem_size = float(sizeof(uint8_t));
} else {
GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT");
}
  2162. float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
  2163. // size of one matrix is element_size * height * width.
  2164. size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
  2165. size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
  2166. // scale stored at the end of weight. Also need transpose.
  2167. GGML_ASSERT(QK4_0 == QK8_0);
  2168. int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
  2169. size_t scale_elem_size = sizeof(uint16_t);
  2170. size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
  2171. scale_elem_size};
  2172. size_t scale_stride = scale_elem_size * src0->ne[0] * src0->ne[1] / QK8_0;
  2173. char* scale_offset = (char*)src0->data + weight_size;
  2174. // input
  2175. void* input_buffer;
  2176. size_t input_elem_size = sizeof(uint16_t);
  2177. int64_t input_ne[] = {src1->ne[0], src1->ne[1]};
  2178. size_t input_nb[] = {input_elem_size, input_elem_size * src1->ne[0]};
  2179. size_t input_stride = input_elem_size * src1->ne[0] * src1->ne[1];
ggml_cann_pool_alloc input_allocator(ctx.pool());
if (src1->type != GGML_TYPE_F16) {
aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1);
input_allocator.alloc(ggml_nelements(src1) * input_elem_size);
input_buffer = input_allocator.get();
  2185. int64_t* input_cast_ne = src1->ne;
  2186. size_t input_cast_nb[GGML_MAX_DIMS];
  2187. input_cast_nb[0] = sizeof(uint16_t);
  2188. for (int i = 1; i < GGML_MAX_DIMS; i++) {
  2189. input_cast_nb[i] = input_cast_nb[i - 1] * input_cast_ne[i - 1];
  2190. }
  2191. aclTensor* acl_input_tensor = ggml_cann_create_tensor(
  2192. input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
  2193. input_cast_nb, GGML_MAX_DIMS);
  2194. aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
  2195. ACL_CHECK(aclDestroyTensor(acl_input_tensor));
  2196. ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
  2197. } else {
  2198. input_buffer = src1->data;
  2199. }
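
    // Note: the batched matmul below runs on FP16 activations and writes FP16
    // output (cast back to the destination type at the end of this function),
    // which is presumably why non-F16 inputs are converted to F16 here.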
    // output
    size_t output_elem_size = sizeof(uint16_t);
    int64_t output_ne[] = {dst->ne[0], dst->ne[1]};
    size_t output_nb[] = {output_elem_size, output_elem_size * dst->ne[0]};
    ggml_cann_pool_alloc output_allocator(
        ctx.pool(), ggml_nelements(dst) * output_elem_size);
    void* output_buffer = output_allocator.get();
    size_t output_stride = output_elem_size * dst->ne[0] * dst->ne[1];

    // aclnn
    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;
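
    // The loops below iterate over the input's batch dims (ne[3], ne[2]) and
    // map each input batch to a weight batch by integer division of the
    // repeat ratio, i.e. GGML-style broadcasting of the batch dimensions.
    // Hypothetical example: with src1->ne[2] == 8 and src0->ne[2] == 2,
    // c0 = c1 / 4, so input batches 0..3 reuse weight batch 0 and input
    // batches 4..7 reuse weight batch 1.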
    for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) {
        for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) {
            int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]);
            int64_t c0 = c1 / (src1->ne[2] / src0->ne[2]);

            int64_t batch1 = n1 * src1->ne[2] + c1;
            int64_t batch0 = n0 * src0->ne[2] + c0;

            aclTensor* acl_input_tensor = ggml_cann_create_tensor(
                (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
                input_elem_size, input_ne, input_nb, 2);
            aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
                (char*)src0->data + batch0 * weight_stride,
                ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
                weight_nb, 2);
            aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
                scale_offset + batch0 * scale_stride, ACL_FLOAT16,
                scale_elem_size, scale_ne, scale_nb, 2);
            aclTensor* acl_output_tensor = ggml_cann_create_tensor(
                (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
                output_elem_size, output_ne, output_nb, 2);

            ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
                acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
                nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
                &workspaceSize, &executor));

            if (workspaceSize > 0 && workspaceAddr == nullptr) {
                ggml_cann_pool_alloc workspace_allocator(ctx.pool(),
                                                         workspaceSize);
                workspaceAddr = workspace_allocator.get();
            }

            ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
                workspaceAddr, workspaceSize, executor, ctx.stream()));

            ACL_CHECK(aclDestroyTensor(acl_input_tensor));
            ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
            ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
            ACL_CHECK(aclDestroyTensor(acl_output_tensor));
        }
    }

    // cast the FP16 output buffer back to the destination tensor
    int64_t* output_cast_ne = dst->ne;
    size_t output_cast_nb[GGML_MAX_DIMS];
    output_cast_nb[0] = sizeof(uint16_t);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
    }

    aclTensor* acl_output_tensor =
        ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size,
                                output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
    aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
    aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ACL_FLOAT);

    ACL_CHECK(aclDestroyTensor(acl_output_tensor));
    ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
}

void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    const enum ggml_type type = dst->src[0]->type;
    switch (type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
            ggml_cann_mat_mul_fp(ctx, dst);
            break;
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q8_0:
            ggml_cann_mul_mat_quant(ctx, dst, type);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

/**
 * @brief Rolls the elements of a tensor along a specified dimension.
 *
 * This function rolls the elements of the source tensor `acl_src` by the
 * specified shifts `shifts` along the specified dimensions `dims`, and stores
 * the result in the destination tensor `acl_dst`.
 *
 * @param ctx The context for the CANN backend operations.
 * @param acl_src The source tensor whose elements will be rolled.
 * @param acl_dst The destination tensor where the rolled elements will be
 * stored.
 * @param shifts An array specifying the number of positions by which elements
 * are shifted.
 * @param dims An array specifying the dimensions along which elements are
 * shifted.
 */
static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
                       aclTensor* acl_dst, int64_t* shifts, int64_t* dims) {
    aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1);
    aclIntArray* acl_dims = aclCreateIntArray(dims, 1);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnRollGetWorkspaceSize(acl_src, acl_shifts, acl_dims, acl_dst,
                                        &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(aclnnRoll(workspaceAddr, workspaceSize, executor, ctx.stream()));

    ACL_CHECK(aclDestroyIntArray(acl_shifts));
    ACL_CHECK(aclDestroyIntArray(acl_dims));
}
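
// Note: both int arrays above are created with length 1, so a single
// (shift, dim) pair is applied per call. Usage sketch with hypothetical
// tensors, matching how ggml_cann_rope() calls it below:
//
//   int64_t shifts[] = {1};  // rotate by one element, wrapping around
//   int64_t dims[]   = {3};  // along the innermost ACL dimension
//   aclnn_roll(ctx, acl_src, acl_dst, shifts, dims);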

/**
 * @brief Fills specified positions of a tensor with a scalar value.
 *
 * This function fills the positions in the source tensor `acl_src` specified
 * by `index` along the dimension `dim` with the scalar value `value`.
 *
 * @param ctx The context for the CANN backend operations.
 * @param acl_src The source tensor where the positions will be filled.
 * @param dim The dimension along which the positions are specified.
 * @param index An array specifying the positions to be filled.
 * @param index_num The number of positions specified in the index array.
 * @param value The scalar value used to fill the specified positions.
 */
static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
                                    aclTensor* acl_src, int64_t dim,
                                    int64_t* index, int64_t index_num,
                                    float value) {
    aclIntArray* acl_index = aclCreateIntArray(index, index_num);
    aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);

    uint64_t workspaceSize = 0;
    aclOpExecutor* executor;
    void* workspaceAddr = nullptr;

    ACL_CHECK(aclnnInplaceIndexFillTensorGetWorkspaceSize(
        acl_src, dim, acl_index, acl_value, &workspaceSize, &executor));
    if (workspaceSize > 0) {
        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
        workspaceAddr = workspace_allocator.get();
    }

    ACL_CHECK(aclnnInplaceIndexFillTensor(workspaceAddr, workspaceSize,
                                          executor, ctx.stream()));

    ACL_CHECK(aclDestroyIntArray(acl_index));
    ACL_CHECK(aclDestroyScalar(acl_value));
}
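
/**
 * @brief Initializes the sin/cos cache used by the RoPE operator.
 *
 * This function computes theta = pos * theta_scale^i over the first half of
 * the head dimension of `dst->src[0]`, takes its sine and cosine, and writes
 * the results into `acl_sin_repeat_tensor` and `acl_cos_repeat_tensor`. The
 * values are repeated along the head dimension in a layout that depends on
 * `is_neox` (tiled halves for NeoX-style RoPE, interleaved pairs otherwise).
 *
 * @param ctx The context for the CANN backend operations.
 * @param dst The RoPE destination tensor; src[0] is the input and src[1]
 * holds the positions.
 * @param acl_cos_repeat_tensor The tensor receiving the repeated cosine
 * values.
 * @param acl_sin_repeat_tensor The tensor receiving the repeated sine values.
 * @param theta_scale The per-index frequency scale, freq_base^(-2/n_dims).
 * @param is_neox Whether the NeoX-style (non-interleaved) RoPE layout is used.
 */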
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                             aclTensor* acl_cos_repeat_tensor,
                             aclTensor* acl_sin_repeat_tensor,
                             float theta_scale, bool is_neox) {
    // init sin/cos cache; the cache uses a different repeat method depending
    // on is_neox
    ggml_tensor* src0 = dst->src[0];  // input
    ggml_tensor* src1 = dst->src[1];  // position

    // arange: [0, 1, ..., ne0/2 - 1]
    int64_t arange_length = src0->ne[0] / 2;
    ggml_cann_pool_alloc arange_allocator(ctx.pool(),
                                          arange_length * sizeof(float_t));
    void* arange_buffer = arange_allocator.get();
    int64_t arange_ne[] = {arange_length, 1, 1, 1};
    size_t arange_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
                          arange_length * sizeof(float_t)};

    aclTensor* acl_arange_tensor =
        ggml_cann_create_tensor(arange_buffer, ACL_FLOAT, sizeof(float_t),
                                arange_ne, arange_nb, GGML_MAX_DIMS);
    float start = 0;
    float step = 1;
    float stop = src0->ne[0] / 2;
    float n_elements = src0->ne[0] / 2;
    aclnn_arange(ctx, acl_arange_tensor, start, stop, step, n_elements);

    // power
    // aclnnPowScalarTensor(): @param self is a tensor which should be a
    // scalar, so use aclnn_pow_tensor_tensor() until that is fixed:
    //   aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale,
    //                                                aclDataType::ACL_FLOAT);
    //   aclnn_power_scalar_tensor(ctx, acl_theta_scale, acl_arange_tensor,
    //                             acl_power_tensor);
    ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
                                               arange_length * sizeof(float_t));
    void* theta_scale_buffer = theta_scale_allocator.get();
    aclTensor* acl_theta_scale_tensor = aclnn_ones(
        ctx, theta_scale_buffer, arange_length * sizeof(float_t), arange_ne,
        GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), theta_scale);
    aclnn_pow_tensor_tensor(ctx, acl_theta_scale_tensor, acl_arange_tensor);
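
    // After the pow above, acl_theta_scale_tensor holds theta_scale^i for
    // i in [0, ne0/2), i.e. the standard RoPE frequencies
    // freq_base^(-2*i/n_dims), given theta_scale = freq_base^(-2/n_dims)
    // as passed in from ggml_cann_rope().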
    // position
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
    int64_t position_length = src1->ne[0];
    int64_t position_ne[] = {1, position_length, 1, 1};
    size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t),
                            sizeof(int32_t) * position_length,
                            sizeof(int32_t) * position_length};
    aclTensor* acl_position_tensor = ggml_cann_create_tensor(
        src1->data, ggml_cann_type_mapping(src1->type),
        ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);

    // power * position
    int64_t theta_length = arange_length * position_length;
    ggml_cann_pool_alloc theta_allocator(ctx.pool(),
                                         theta_length * sizeof(float_t));
    void* theta_buffer = theta_allocator.get();
    int64_t theta_ne[] = {arange_length, position_length, 1, 1};
    size_t theta_nb[GGML_MAX_DIMS];
    theta_nb[0] = sizeof(float_t);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
    }
    aclTensor* acl_theta_tensor =
        ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
                                theta_ne, theta_nb, GGML_MAX_DIMS);
    aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
              acl_theta_tensor);
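
    // This broadcasted multiply is effectively an outer product:
    // theta[i, p] = pos[p] * theta_scale^i, with shape
    // {arange_length, position_length}, i.e. one rotation angle per
    // (frequency index, position) pair.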
    // permute: [0,1,2,3] -> [0,2,1,3]
    int64_t permute_ne[] = {arange_length, 1, position_length, 1};
    size_t permute_nb[GGML_MAX_DIMS];
    permute_nb[0] = sizeof(float_t);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        permute_nb[i] = permute_nb[i - 1] * permute_ne[i - 1];
    }

    ggml_cann_pool_alloc permute_allocator(ctx.pool(),
                                           theta_length * sizeof(float_t));
    void* permute_buffer = permute_allocator.get();
    aclTensor* acl_permute_tensor = ggml_cann_create_tensor(
        permute_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
        GGML_MAX_DIMS, ACL_FORMAT_ND);
    int64_t permute_dim[] = {0, 2, 1, 3};
    int64_t num_dims = 4;
    aclnn_permute(ctx, acl_theta_tensor, acl_permute_tensor, permute_dim,
                  num_dims);

    // sin/cos
    ggml_cann_pool_alloc sin_allocator(ctx.pool(),
                                       theta_length * sizeof(float_t));
    void* sin_buffer = sin_allocator.get();
    aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
        sin_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
        GGML_MAX_DIMS, ACL_FORMAT_ND);
    aclnn_sin(ctx, acl_permute_tensor, acl_sin_tensor);

    ggml_cann_pool_alloc cos_allocator(ctx.pool(),
                                       theta_length * sizeof(float_t));
    void* cos_buffer = cos_allocator.get();
    aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
        cos_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
        GGML_MAX_DIMS, ACL_FORMAT_ND);
    aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor);

    // repeat
    if (is_neox) {
        int64_t repeatsArray[] = {1, 1, 1, 2};
        aclnn_repeat(ctx, acl_sin_tensor, acl_sin_repeat_tensor, repeatsArray);
        aclnn_repeat(ctx, acl_cos_tensor, acl_cos_repeat_tensor, repeatsArray);
    } else {
        int64_t num_repeats = 2;
        int64_t dim = 3;
        int64_t output_size = arange_length * num_repeats;
        aclnn_repeat_interleave(ctx, acl_sin_tensor, acl_sin_repeat_tensor, dim,
                                num_repeats, output_size);
        aclnn_repeat_interleave(ctx, acl_cos_tensor, acl_cos_repeat_tensor, dim,
                                num_repeats, output_size);
    }
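
    // Resulting layouts along the head dimension (assuming aclnn_repeat tiles
    // the whole block and aclnn_repeat_interleave duplicates adjacent
    // elements):
    //   is_neox: [s0, s1, ..., s_{d/2-1}, s0, s1, ..., s_{d/2-1}]
    //   default: [s0, s0, s1, s1, ..., s_{d/2-1}, s_{d/2-1}]
    // matching the pairing of rotated elements used in ggml_cann_rope().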
    // release
    ACL_CHECK(aclDestroyTensor(acl_arange_tensor));
    ACL_CHECK(aclDestroyTensor(acl_theta_scale_tensor));
    ACL_CHECK(aclDestroyTensor(acl_position_tensor));
    ACL_CHECK(aclDestroyTensor(acl_theta_tensor));
    ACL_CHECK(aclDestroyTensor(acl_permute_tensor));
    ACL_CHECK(aclDestroyTensor(acl_sin_tensor));
    ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
}

void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
    // TODO: use ascendc
    // Only tested with the LLAMA model.
    ggml_tensor* src0 = dst->src[0];  // input
    ggml_tensor* src2 = dst->src[2];  // freq_factors

    // TODO: with freq_factors
    GGML_ASSERT(src2 == NULL);

    // param
    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
    // const int n_past     = ((int32_t *) dst->op_params)[0];
    const int n_dims     = ((int32_t*)dst->op_params)[1];
    const int mode       = ((int32_t*)dst->op_params)[2];
    // const int n_ctx      = ((int32_t *) dst->op_params)[3];
    const int n_ctx_orig = ((int32_t*)dst->op_params)[4];

    GGML_TENSOR_UNARY_OP_LOCALS

    memcpy(&freq_base,   (int32_t*)dst->op_params + 5,  sizeof(float));
    memcpy(&freq_scale,  (int32_t*)dst->op_params + 6,  sizeof(float));
    memcpy(&ext_factor,  (int32_t*)dst->op_params + 7,  sizeof(float));
    memcpy(&attn_factor, (int32_t*)dst->op_params + 8,  sizeof(float));
    memcpy(&beta_fast,   (int32_t*)dst->op_params + 9,  sizeof(float));
    memcpy(&beta_slow,   (int32_t*)dst->op_params + 10, sizeof(float));

    GGML_ASSERT(n_dims <= ne0);
    GGML_ASSERT(n_dims % 2 == 0);

    // TODO: ext_factor != 0
    GGML_ASSERT(ext_factor == 0);
    // TODO: freq_scale != 1
    GGML_ASSERT(freq_scale == 1);

    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    float corr_dims[2];
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
                             beta_slow, corr_dims);

    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;

    // init cos/sin cache
    ggml_cann_pool_alloc sin_allocator(
        ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t));
    ggml_cann_pool_alloc cos_allocator(
        ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t));
    void* sin_buffer = sin_allocator.get();
    void* cos_buffer = cos_allocator.get();

    int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
    size_t sin_reshape_nb[GGML_MAX_DIMS];
    sin_reshape_nb[0] = sizeof(float_t);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
    }
    aclTensor* acl_sin_reshape_tensor =
        ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
                                sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
    aclTensor* acl_cos_reshape_tensor =
        ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
                                sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
    aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
                     theta_scale, is_neox);
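
    // The code below implements the RoPE rotation as
    //   dst = src * cos(theta) + rotate(src) * sin(theta)
    // where rotate(src) is the rolled input multiplied by a +/-1 sign
    // pattern: pairwise (-q1, q0, -q3, q2, ...) in the default mode, and
    // half-swapped (-q_{d/2}, ..., -q_{d-1}, q0, ..., q_{d/2-1}) in the NeoX
    // mode. (Description inferred from the roll/sign/mul/add steps below.)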
    // roll input
    void* input_roll_buffer;
    aclTensor* acl_minus_one_tensor;
    void* minus_one_scale_buffer = nullptr;
    ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
    ggml_cann_pool_alloc minus_one_scale_allocator(
        ctx.pool(), sizeof(float_t) * src0->ne[0]);
    if (!is_neox) {
        // roll input: [q0, q1, q2, q3, ...] -> [q1, q0, q3, q2, ...]
        input_roll_buffer = roll_allocator.get();
        int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2),
                                    src0->ne[2], src0->ne[3]};
        size_t input_roll_nb[GGML_MAX_DIMS];
        input_roll_nb[0] = ggml_type_size(src0->type);
        for (int i = 1; i < GGML_MAX_DIMS; i++) {
            input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1];
        }
        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
            input_roll_buffer, ggml_cann_type_mapping(src0->type),
            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
            GGML_MAX_DIMS);
        aclTensor* acl_input_tensor = ggml_cann_create_tensor(
            src0->data, ggml_cann_type_mapping(src0->type),
            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
            GGML_MAX_DIMS);

        int64_t shifts[] = {1};
        int64_t dims[] = {3};
        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
        ACL_CHECK(aclDestroyTensor(acl_input_tensor));

        // init [-1, 1, -1, 1, ...]
        minus_one_scale_buffer = minus_one_scale_allocator.get();

        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
        size_t minus_one_nb[GGML_MAX_DIMS];
        minus_one_nb[0] = sizeof(float_t);
        for (int i = 1; i < GGML_MAX_DIMS; i++) {
            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
        }
        acl_minus_one_tensor = aclnn_ones(
            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
        int64_t dim = 3;
        int64_t* index = new int64_t[src0->ne[0]];
        for (int i = 0; i < src0->ne[0]; i++) {
            index[i] = i / 2 * 2;
        }
        int64_t index_num = src0->ne[0];
        float value = -1;
        aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index,
                                index_num, value);
        // the index values are consumed when the aclIntArray is created, so
        // the host-side buffer can be released here
        delete[] index;
    } else {
        // roll input: [q0, q1, q2, ...] ->
        // [q_half, q_half+1, ..., q_end, q0, q1, ..., q_half-1]
        input_roll_buffer = roll_allocator.get();
        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
            input_roll_buffer, ggml_cann_type_mapping(src0->type),
            ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
        aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0);

        int64_t shifts[] = {src0->ne[0] / 2};
        int64_t dims[] = {3};
        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
        ACL_CHECK(aclDestroyTensor(acl_input_tensor));

        // init [-1, -1, ..., -1, 1, 1, ..., 1]
        minus_one_scale_buffer = minus_one_scale_allocator.get();

        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
        size_t minus_one_nb[GGML_MAX_DIMS];
        minus_one_nb[0] = sizeof(float_t);
        for (int i = 1; i < GGML_MAX_DIMS; i++) {
            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
        }
        acl_minus_one_tensor = aclnn_ones(
            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
        // -1 * first half
        int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
        size_t first_half_nb[GGML_MAX_DIMS];
        first_half_nb[0] = sizeof(float_t);
        for (int i = 1; i < GGML_MAX_DIMS; i++) {
            first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
        }
        aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
            minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
            first_half_nb, GGML_MAX_DIMS);
        bool inplace = true;
        float scale = -1;
        aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
        ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
    }
    // TODO: n_dims < ne0
    GGML_ASSERT(n_dims == src0->ne[0]);

    // input * scale
    ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(),
                                                  ggml_nbytes(src0));
    void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get();
    size_t input_nb[GGML_MAX_DIMS];
    input_nb[0] = ggml_type_size(src0->type);
    for (int i = 1; i < GGML_MAX_DIMS; i++) {
        input_nb[i] = input_nb[i - 1] * src0->ne[i - 1];
    }
    aclTensor* acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor(
        input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type),
        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
    aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor(
        input_roll_buffer, ggml_cann_type_mapping(src0->type),
        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);

    aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor,
              acl_input_roll_mul_scale_tensor);

    // output
    aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
    void* output_fp32_buffer;
    if (src0->type == GGML_TYPE_F32) {
        aclnn_inplace_mul(ctx, acl_src0, acl_cos_reshape_tensor);
        aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor,
                          acl_sin_reshape_tensor);
        aclnn_add(ctx, acl_src0, acl_input_roll_mul_scale_tensor, acl_dst);
        // TODO: ne0 != n_dims in mode2
    } else if (src0->type == GGML_TYPE_F16) {
        size_t input_fp32_nb[GGML_MAX_DIMS];
        input_fp32_nb[0] = sizeof(float_t);
        for (int i = 1; i < GGML_MAX_DIMS; i++) {
            input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
        }
        ggml_cann_pool_alloc fp32_allocator1(
            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
        void* input_fp32_buffer1 = fp32_allocator1.get();
        aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
            input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
            input_fp32_nb, GGML_MAX_DIMS);
        ggml_cann_pool_alloc fp32_allocator2(
            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
        void* input_fp32_buffer2 = fp32_allocator2.get();
        aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
            input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
            input_fp32_nb, GGML_MAX_DIMS);

        ggml_cann_pool_alloc fp32_allocator(
            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
        output_fp32_buffer = fp32_allocator.get();
        aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
            output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
            input_fp32_nb, GGML_MAX_DIMS);
        aclnn_mul(ctx, acl_src0, acl_cos_reshape_tensor, input_fp32_tensor1);
        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
                  input_fp32_tensor2);
        aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2,
                  output_fp32_tensor);
        aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);

        ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
        ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
        ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
    }

    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
    ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
    ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
    ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
    ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
    ACL_CHECK(aclDestroyTensor(acl_src0));
    ACL_CHECK(aclDestroyTensor(acl_dst));
}