ggml.h 91 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492
  1. #pragma once
  2. //
  3. // GGML Tensor Library
  4. //
  5. // This documentation is still a work in progress.
  6. // If you would like specific topics to be covered, feel free to drop a comment:
  7. //
  8. // https://github.com/ggerganov/whisper.cpp/issues/40
  9. //
  10. // ## Overview
  11. //
  12. // This library implements:
  13. //
  14. // - a set of tensor operations
  15. // - automatic differentiation
  16. // - basic optimization algorithms
  17. //
  18. // The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
  19. // but is not limited to, the following:
  20. //
  21. // - linear regression
  22. // - support vector machines
  23. // - neural networks
  24. //
  25. // The library allows the user to define a certain function using the available tensor operations. This function
  26. // definition is represented internally via a computation graph. Each tensor operation in the function definition
  27. // corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
  28. // function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
  29. // using one of the available optimization algorithms.
  30. //
  31. // For example, here we define the function: f(x) = a*x^2 + b
  32. //
  33. // {
  34. // struct ggml_init_params params = {
  35. // .mem_size = 16*1024*1024,
  36. // .mem_buffer = NULL,
  37. // };
  38. //
  39. // // memory allocation happens here
  40. // struct ggml_context * ctx = ggml_init(params);
  41. //
  42. // struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  43. //
  44. // ggml_set_param(ctx, x); // x is an input variable
  45. //
  46. // struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  47. // struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  48. // struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
  49. // struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
  50. //
  51. // ...
  52. // }
  53. //
  54. // Notice that the function definition above does not involve any actual computation. The computation is performed only
  55. // when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
  56. //
  57. // {
  58. // ...
  59. //
  60. // struct ggml_cgraph * gf = ggml_new_graph(ctx);
  61. // ggml_build_forward_expand(gf, f);
  62. //
  63. // // set the input variable and parameter values
  64. // ggml_set_f32(x, 2.0f);
  65. // ggml_set_f32(a, 3.0f);
  66. // ggml_set_f32(b, 4.0f);
  67. //
  68. // ggml_graph_compute_with_ctx(ctx, gf, n_threads);
  69. //
  70. // printf("f = %f\n", ggml_get_f32_1d(f, 0));
  71. //
  72. // ...
  73. // }
  74. //
  75. // The actual computation is performed in the ggml_graph_compute() function.
  76. //
  77. // The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
  78. // ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
  79. // in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
  80. // and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was
  81. // actually needed.
  82. //
  83. // The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
  84. // differentiation and optimization algorithms.
  85. //
  86. // The described approach allows the user to define the function graph once and then compute its forward or backward graphs
  87. // multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way
  88. // the user can avoid the memory allocation overhead at runtime.
  89. //
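  90. // The library supports multi-dimensional tensors - up to 4 dimensions (GGML_MAX_DIMS). The FP16 and FP32 data types
  91. // are first-class citizens; BF16, integer (I8 to I64) and a number of quantized types are also available (see enum ggml_type), and the library can in principle be extended to further types such as FP8.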
  92. //
  93. // Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
  94. // and binary operations. Most of the available operations fall into one of these two categories. With time, it became
  95. // clear that the library needs to support more complex operations. The way to support these operations is not clear
  96. // yet, but a few examples are demonstrated in the following operations:
  97. //
  98. // - ggml_permute()
  99. // - ggml_conv_1d_1s()
  100. // - ggml_conv_1d_2s()
  101. //
  102. // For each tensor operator, the library implements a forward and backward computation function. The forward function
  103. // computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
  104. // input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
  105. // calculus class, or watch the following video:
  106. //
  107. // What is Automatic Differentiation?
  108. // https://www.youtube.com/watch?v=wG_nF1awSSY
  109. //
  110. //
  111. // ## Tensor data (struct ggml_tensor)
  112. //
  113. // The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
  114. // the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
  115. // pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
  116. //
  117. // {
  118. // struct ggml_tensor * c = ggml_add(ctx, a, b);
  119. //
  120. // assert(c->src[0] == a);
  121. // assert(c->src[1] == b);
  122. // }
  123. //
  124. // The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
  125. // number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows
  126. // to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
  127. // permutation. All tensor operations have to take the stride into account and not assume that the tensor is
  128. // contiguous in memory.
  129. //
  130. // The data of the tensor is accessed via the "data" pointer. For example:
  131. //
  132. // {
  133. // const int nx = 2;
  134. // const int ny = 3;
  135. //
  136. // struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny);
  137. //
  138. // for (int y = 0; y < ny; y++) {
  139. // for (int x = 0; x < nx; x++) {
  140. // *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y;
  141. // }
  142. // }
  143. //
  144. // ...
  145. // }
  146. //
  147. // Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d(), that can be used.
  148. //
  149. // ## The matrix multiplication operator (ggml_mul_mat)
  150. //
  151. // TODO
  152. //
  153. //
  154. // ## Multi-threading
  155. //
  156. // TODO
  157. //
  158. //
  159. // ## Overview of ggml.c
  160. //
  161. // TODO
  162. //
  163. //
  164. // ## SIMD optimizations
  165. //
  166. // TODO
  167. //
  168. //
  169. // ## Debugging ggml
  170. //
  171. // TODO
  172. //
  173. //
  174. #ifdef GGML_SHARED
  175. # if defined(_WIN32) && !defined(__MINGW32__)
  176. # ifdef GGML_BUILD
  177. # define GGML_API __declspec(dllexport) extern
  178. # else
  179. # define GGML_API __declspec(dllimport) extern
  180. # endif
  181. # else
  182. # define GGML_API __attribute__ ((visibility ("default"))) extern
  183. # endif
  184. #else
  185. # define GGML_API extern
  186. #endif
  187. // TODO: support for clang
  188. #ifdef __GNUC__
  189. # define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
  190. #elif defined(_MSC_VER)
  191. # define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func
  192. #else
  193. # define GGML_DEPRECATED(func, hint) func
  194. #endif
  195. #ifndef __GNUC__
  196. # define GGML_ATTRIBUTE_FORMAT(...)
  197. #elif defined(__MINGW32__) && !defined(__clang__)
  198. # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
  199. #else
  200. # define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
  201. #endif
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
  206. #define GGML_FILE_MAGIC 0x67676d6c // "ggml"
  207. #define GGML_FILE_VERSION 2
  208. #define GGML_QNT_VERSION 2 // bump this on quantization format changes
  209. #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
  210. #define GGML_MAX_DIMS 4
  211. #define GGML_MAX_PARAMS 2048
  212. #define GGML_MAX_SRC 10
  213. #define GGML_MAX_N_THREADS 512
  214. #define GGML_MAX_OP_PARAMS 64
  215. #ifndef GGML_MAX_NAME
  216. # define GGML_MAX_NAME 64
  217. #endif
  218. #define GGML_DEFAULT_N_THREADS 4
  219. #define GGML_DEFAULT_GRAPH_SIZE 2048
  220. #if UINTPTR_MAX == 0xFFFFFFFF
  221. #define GGML_MEM_ALIGN 4
  222. #else
  223. #define GGML_MEM_ALIGN 16
  224. #endif
  225. #define GGML_EXIT_SUCCESS 0
  226. #define GGML_EXIT_ABORTED 1
  227. #define GGML_ROPE_TYPE_NEOX 2
  228. #define GGML_ROPE_TYPE_MROPE 8
  229. #define GGML_ROPE_TYPE_VISION 24
  230. #define GGML_MROPE_SECTIONS 4
  231. #define GGML_UNUSED(x) (void)(x)
  232. #ifdef __CUDACC__
  233. template<typename... Args>
  234. __host__ __device__ constexpr inline void ggml_unused_vars_impl(Args&&...) noexcept {}
  235. #define GGML_UNUSED_VARS(...) ggml_unused_vars_impl(__VA_ARGS__)
  236. #else
  237. #define GGML_UNUSED_VARS(...) do { (void)sizeof((__VA_ARGS__, 0)); } while(0)
  238. #endif // __CUDACC__
  239. #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))
  240. #ifndef NDEBUG
  241. # define GGML_UNREACHABLE() do { fprintf(stderr, "statement should be unreachable\n"); abort(); } while(0)
  242. #elif defined(__GNUC__)
  243. # define GGML_UNREACHABLE() __builtin_unreachable()
  244. #elif defined(_MSC_VER)
  245. # define GGML_UNREACHABLE() __assume(0)
  246. #else
  247. # define GGML_UNREACHABLE() ((void) 0)
  248. #endif
  249. #ifdef __cplusplus
  250. # define GGML_NORETURN [[noreturn]]
  251. #elif defined(_MSC_VER)
  252. # define GGML_NORETURN __declspec(noreturn)
  253. #else
  254. # define GGML_NORETURN _Noreturn
  255. #endif
  256. #define GGML_ABORT(...) ggml_abort(__FILE__, __LINE__, __VA_ARGS__)
  257. #define GGML_ASSERT(x) if (!(x)) GGML_ABORT("GGML_ASSERT(%s) failed", #x)
  258. // used to copy the number of elements and stride in bytes of tensors into local variables.
  259. // main purpose is to reduce code duplication and improve readability.
  260. //
  261. // example:
  262. //
  263. // GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
  264. // GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
  265. //
  266. #define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
  267. const type prefix##0 = (pointer)->array[0]; \
  268. GGML_UNUSED(prefix##0);
  269. #define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
  270. GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
  271. const type prefix##1 = (pointer)->array[1]; \
  272. GGML_UNUSED(prefix##1);
  273. #define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
  274. GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
  275. const type prefix##2 = (pointer)->array[2]; \
  276. GGML_UNUSED(prefix##2);
  277. #define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
  278. GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
  279. const type prefix##3 = (pointer)->array[3]; \
  280. GGML_UNUSED(prefix##3);
  281. #define GGML_TENSOR_UNARY_OP_LOCALS \
  282. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
  283. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
  284. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
  285. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  286. #define GGML_TENSOR_BINARY_OP_LOCALS \
  287. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
  288. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
  289. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
  290. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
  291. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
  292. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  293. #define GGML_TENSOR_TERNARY_OP_LOCALS \
  294. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
  295. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
  296. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
  297. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
  298. GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
  299. GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
  300. GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
  301. GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
  302. #define GGML_TENSOR_BINARY_OP_LOCALS01 \
  303. GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
  304. GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
  305. GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
  306. GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
  307. #ifdef __cplusplus
  308. extern "C" {
  309. #endif
  310. // Function type used in fatal error callbacks
  311. typedef void (*ggml_abort_callback_t)(const char * error_message);
  312. // Set the abort callback (passing null will restore original abort functionality: printing a message to stdout)
  313. // Returns the old callback for chaining
  314. GGML_API ggml_abort_callback_t ggml_set_abort_callback(ggml_abort_callback_t callback);
  315. GGML_NORETURN GGML_ATTRIBUTE_FORMAT(3, 4)
  316. GGML_API void ggml_abort(const char * file, int line, const char * fmt, ...);
  317. enum ggml_status {
  318. GGML_STATUS_ALLOC_FAILED = -2,
  319. GGML_STATUS_FAILED = -1,
  320. GGML_STATUS_SUCCESS = 0,
  321. GGML_STATUS_ABORTED = 1,
  322. };
  323. // get ggml_status name string
  324. GGML_API const char * ggml_status_to_string(enum ggml_status status);
  325. // ieee 754-2008 half-precision float16
  326. // todo: make this not an integral type
  327. typedef uint16_t ggml_fp16_t;
  328. GGML_API float ggml_fp16_to_fp32(ggml_fp16_t);
  329. GGML_API ggml_fp16_t ggml_fp32_to_fp16(float);
  330. GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t *, float *, int64_t);
  331. GGML_API void ggml_fp32_to_fp16_row(const float *, ggml_fp16_t *, int64_t);
  332. // google brain half-precision bfloat16
  333. typedef struct { uint16_t bits; } ggml_bf16_t;
  334. GGML_API ggml_bf16_t ggml_fp32_to_bf16(float);
  335. GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16
  336. GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t);
  337. GGML_API void ggml_fp32_to_bf16_row_ref(const float *, ggml_bf16_t *, int64_t);
  338. GGML_API void ggml_fp32_to_bf16_row(const float *, ggml_bf16_t *, int64_t);
  339. struct ggml_object;
  340. struct ggml_context;
  341. struct ggml_cgraph;
  342. // NOTE: always add types at the end of the enum to keep backward compatibility
  343. enum ggml_type {
  344. GGML_TYPE_F32 = 0,
  345. GGML_TYPE_F16 = 1,
  346. GGML_TYPE_Q4_0 = 2,
  347. GGML_TYPE_Q4_1 = 3,
  348. // GGML_TYPE_Q4_2 = 4, support has been removed
  349. // GGML_TYPE_Q4_3 = 5, support has been removed
  350. GGML_TYPE_Q5_0 = 6,
  351. GGML_TYPE_Q5_1 = 7,
  352. GGML_TYPE_Q8_0 = 8,
  353. GGML_TYPE_Q8_1 = 9,
  354. GGML_TYPE_Q2_K = 10,
  355. GGML_TYPE_Q3_K = 11,
  356. GGML_TYPE_Q4_K = 12,
  357. GGML_TYPE_Q5_K = 13,
  358. GGML_TYPE_Q6_K = 14,
  359. GGML_TYPE_Q8_K = 15,
  360. GGML_TYPE_IQ2_XXS = 16,
  361. GGML_TYPE_IQ2_XS = 17,
  362. GGML_TYPE_IQ3_XXS = 18,
  363. GGML_TYPE_IQ1_S = 19,
  364. GGML_TYPE_IQ4_NL = 20,
  365. GGML_TYPE_IQ3_S = 21,
  366. GGML_TYPE_IQ2_S = 22,
  367. GGML_TYPE_IQ4_XS = 23,
  368. GGML_TYPE_I8 = 24,
  369. GGML_TYPE_I16 = 25,
  370. GGML_TYPE_I32 = 26,
  371. GGML_TYPE_I64 = 27,
  372. GGML_TYPE_F64 = 28,
  373. GGML_TYPE_IQ1_M = 29,
  374. GGML_TYPE_BF16 = 30,
  375. // GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files
  376. // GGML_TYPE_Q4_0_4_8 = 32,
  377. // GGML_TYPE_Q4_0_8_8 = 33,
  378. GGML_TYPE_TQ1_0 = 34,
  379. GGML_TYPE_TQ2_0 = 35,
  380. // GGML_TYPE_IQ4_NL_4_4 = 36,
  381. // GGML_TYPE_IQ4_NL_4_8 = 37,
  382. // GGML_TYPE_IQ4_NL_8_8 = 38,
  383. GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
  384. GGML_TYPE_COUNT = 40,
  385. };
  386. // precision
  387. enum ggml_prec {
  388. GGML_PREC_DEFAULT = 0, // stored as ggml_tensor.op_params, 0 by default
  389. GGML_PREC_F32 = 10,
  390. };
  391. // model file types
  392. enum ggml_ftype {
  393. GGML_FTYPE_UNKNOWN = -1,
  394. GGML_FTYPE_ALL_F32 = 0,
  395. GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
  396. GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
  397. GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
  398. GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
  399. GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
  400. GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
  401. GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
  402. GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
  403. GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
  404. GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
  405. GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
  406. GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
  407. GGML_FTYPE_MOSTLY_IQ2_XXS = 15, // except 1d tensors
  408. GGML_FTYPE_MOSTLY_IQ2_XS = 16, // except 1d tensors
  409. GGML_FTYPE_MOSTLY_IQ3_XXS = 17, // except 1d tensors
  410. GGML_FTYPE_MOSTLY_IQ1_S = 18, // except 1d tensors
  411. GGML_FTYPE_MOSTLY_IQ4_NL = 19, // except 1d tensors
  412. GGML_FTYPE_MOSTLY_IQ3_S = 20, // except 1d tensors
  413. GGML_FTYPE_MOSTLY_IQ2_S = 21, // except 1d tensors
  414. GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
  415. GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
  416. GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
  417. GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
  418. };
  419. // available tensor operations:
  420. enum ggml_op {
  421. GGML_OP_NONE = 0,
  422. GGML_OP_DUP,
  423. GGML_OP_ADD,
  424. GGML_OP_ADD_ID,
  425. GGML_OP_ADD1,
  426. GGML_OP_ACC,
  427. GGML_OP_SUB,
  428. GGML_OP_MUL,
  429. GGML_OP_DIV,
  430. GGML_OP_SQR,
  431. GGML_OP_SQRT,
  432. GGML_OP_LOG,
  433. GGML_OP_SIN,
  434. GGML_OP_COS,
  435. GGML_OP_SUM,
  436. GGML_OP_SUM_ROWS,
  437. GGML_OP_MEAN,
  438. GGML_OP_ARGMAX,
  439. GGML_OP_COUNT_EQUAL,
  440. GGML_OP_REPEAT,
  441. GGML_OP_REPEAT_BACK,
  442. GGML_OP_CONCAT,
  443. GGML_OP_SILU_BACK,
  444. GGML_OP_NORM, // normalize
  445. GGML_OP_RMS_NORM,
  446. GGML_OP_RMS_NORM_BACK,
  447. GGML_OP_GROUP_NORM,
  448. GGML_OP_L2_NORM,
  449. GGML_OP_MUL_MAT,
  450. GGML_OP_MUL_MAT_ID,
  451. GGML_OP_OUT_PROD,
  452. GGML_OP_SCALE,
  453. GGML_OP_SET,
  454. GGML_OP_CPY,
  455. GGML_OP_CONT,
  456. GGML_OP_RESHAPE,
  457. GGML_OP_VIEW,
  458. GGML_OP_PERMUTE,
  459. GGML_OP_TRANSPOSE,
  460. GGML_OP_GET_ROWS,
  461. GGML_OP_GET_ROWS_BACK,
  462. GGML_OP_SET_ROWS,
  463. GGML_OP_DIAG,
  464. GGML_OP_DIAG_MASK_INF,
  465. GGML_OP_DIAG_MASK_ZERO,
  466. GGML_OP_SOFT_MAX,
  467. GGML_OP_SOFT_MAX_BACK,
  468. GGML_OP_ROPE,
  469. GGML_OP_ROPE_BACK,
  470. GGML_OP_CLAMP,
  471. GGML_OP_CONV_TRANSPOSE_1D,
  472. GGML_OP_IM2COL,
  473. GGML_OP_IM2COL_BACK,
  474. GGML_OP_CONV_2D,
  475. GGML_OP_CONV_3D,
  476. GGML_OP_CONV_2D_DW,
  477. GGML_OP_CONV_TRANSPOSE_2D,
  478. GGML_OP_POOL_1D,
  479. GGML_OP_POOL_2D,
  480. GGML_OP_POOL_2D_BACK,
  481. GGML_OP_UPSCALE,
  482. GGML_OP_PAD,
  483. GGML_OP_PAD_REFLECT_1D,
  484. GGML_OP_ROLL,
  485. GGML_OP_ARANGE,
  486. GGML_OP_TIMESTEP_EMBEDDING,
  487. GGML_OP_ARGSORT,
  488. GGML_OP_LEAKY_RELU,
  489. GGML_OP_FLASH_ATTN_EXT,
  490. GGML_OP_FLASH_ATTN_BACK,
  491. GGML_OP_SSM_CONV,
  492. GGML_OP_SSM_SCAN,
  493. GGML_OP_WIN_PART,
  494. GGML_OP_WIN_UNPART,
  495. GGML_OP_GET_REL_POS,
  496. GGML_OP_ADD_REL_POS,
  497. GGML_OP_RWKV_WKV6,
  498. GGML_OP_GATED_LINEAR_ATTN,
  499. GGML_OP_RWKV_WKV7,
  500. GGML_OP_UNARY,
  501. GGML_OP_MAP_CUSTOM1,
  502. GGML_OP_MAP_CUSTOM2,
  503. GGML_OP_MAP_CUSTOM3,
  504. GGML_OP_CUSTOM,
  505. GGML_OP_CROSS_ENTROPY_LOSS,
  506. GGML_OP_CROSS_ENTROPY_LOSS_BACK,
  507. GGML_OP_OPT_STEP_ADAMW,
  508. GGML_OP_OPT_STEP_SGD,
  509. GGML_OP_GLU,
  510. GGML_OP_COUNT,
  511. };
  512. enum ggml_unary_op {
  513. GGML_UNARY_OP_ABS,
  514. GGML_UNARY_OP_SGN,
  515. GGML_UNARY_OP_NEG,
  516. GGML_UNARY_OP_STEP,
  517. GGML_UNARY_OP_TANH,
  518. GGML_UNARY_OP_ELU,
  519. GGML_UNARY_OP_RELU,
  520. GGML_UNARY_OP_SIGMOID,
  521. GGML_UNARY_OP_GELU,
  522. GGML_UNARY_OP_GELU_QUICK,
  523. GGML_UNARY_OP_SILU,
  524. GGML_UNARY_OP_HARDSWISH,
  525. GGML_UNARY_OP_HARDSIGMOID,
  526. GGML_UNARY_OP_EXP,
  527. GGML_UNARY_OP_GELU_ERF,
  528. GGML_UNARY_OP_COUNT,
  529. };
  530. enum ggml_glu_op {
  531. GGML_GLU_OP_REGLU,
  532. GGML_GLU_OP_GEGLU,
  533. GGML_GLU_OP_SWIGLU,
  534. GGML_GLU_OP_SWIGLU_OAI,
  535. GGML_GLU_OP_GEGLU_ERF,
  536. GGML_GLU_OP_GEGLU_QUICK,
  537. GGML_GLU_OP_COUNT,
  538. };
  539. enum ggml_object_type {
  540. GGML_OBJECT_TYPE_TENSOR,
  541. GGML_OBJECT_TYPE_GRAPH,
  542. GGML_OBJECT_TYPE_WORK_BUFFER
  543. };
  544. enum ggml_log_level {
  545. GGML_LOG_LEVEL_NONE = 0,
  546. GGML_LOG_LEVEL_DEBUG = 1,
  547. GGML_LOG_LEVEL_INFO = 2,
  548. GGML_LOG_LEVEL_WARN = 3,
  549. GGML_LOG_LEVEL_ERROR = 4,
  550. GGML_LOG_LEVEL_CONT = 5, // continue previous log
  551. };
  552. // this tensor...
  553. enum ggml_tensor_flag {
  554. GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph
  555. GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
  556. GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters
  557. GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
  558. };
  559. struct ggml_init_params {
  560. // memory pool
  561. size_t mem_size; // bytes
  562. void * mem_buffer; // if NULL, memory will be allocated internally
  563. bool no_alloc; // don't allocate memory for the tensor data
  564. };
  565. // n-dimensional tensor
  566. struct ggml_tensor {
  567. enum ggml_type type;
  568. struct ggml_backend_buffer * buffer;
  569. int64_t ne[GGML_MAX_DIMS]; // number of elements
  570. size_t nb[GGML_MAX_DIMS]; // stride in bytes:
  571. // nb[0] = ggml_type_size(type)
  572. // nb[1] = nb[0] * (ne[0] / ggml_blck_size(type)) + padding
  573. // nb[i] = nb[i-1] * ne[i-1]
  574. // compute data
  575. enum ggml_op op;
  576. // op params - allocated as int32_t for alignment
  577. int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
  578. int32_t flags;
  579. struct ggml_tensor * src[GGML_MAX_SRC];
  580. // source tensor and offset for views
  581. struct ggml_tensor * view_src;
  582. size_t view_offs;
  583. void * data;
  584. char name[GGML_MAX_NAME];
  585. void * extra; // extra things e.g. for ggml-cuda.cu
  586. char padding[8];
  587. };
  588. static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
  589. // Abort callback (data: opaque user pointer handed back to the callback - presumably supplied alongside the callback; confirm)
  590. // If not NULL, called before ggml computation
  591. // If it returns true, the computation is aborted
  592. typedef bool (*ggml_abort_callback)(void * data);
  593. //
  594. // GUID
  595. //
  596. // GUID types: a GUID is a raw 16-byte array; ggml_guid_t points to one such array
  597. typedef uint8_t ggml_guid[16];
  598. typedef ggml_guid * ggml_guid_t;
  599. GGML_API bool ggml_guid_matches(ggml_guid_t guid_a, ggml_guid_t guid_b); // true if the two GUIDs match (presumably byte-wise equality of all 16 bytes - confirm)
  600. // misc: library identification, timing and file helpers
  601. GGML_API const char * ggml_version(void);
  602. GGML_API const char * ggml_commit(void);
  603. GGML_API void ggml_time_init(void); // call this once at the beginning of the program
  604. GGML_API int64_t ggml_time_ms(void); // time in milliseconds (reference point presumably set by ggml_time_init - confirm)
  605. GGML_API int64_t ggml_time_us(void); // time in microseconds (same caveat as ggml_time_ms)
  606. GGML_API int64_t ggml_cycles(void);
  607. GGML_API int64_t ggml_cycles_per_ms(void);
  608. // accepts a UTF-8 path, even on Windows
  609. GGML_API FILE * ggml_fopen(const char * fname, const char * mode); // mode is an fopen-style mode string
  610. GGML_API void ggml_print_object (const struct ggml_object * obj); // debug print of a single context object
  611. GGML_API void ggml_print_objects(const struct ggml_context * ctx); // debug print of all objects in ctx
  612. GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor); // total number of elements
  613. GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); // number of rows
  614. GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); // size of the tensor data in bytes
  615. GGML_API size_t ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN
  616. GGML_API int64_t ggml_blck_size(enum ggml_type type); // number of elements per block of this type
  617. GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block
  618. GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row of ne elements
  619. GGML_DEPRECATED(
  620. GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as a double, i.e. bytes per element
  621. "use ggml_row_size() instead");
  622. GGML_API const char * ggml_type_name(enum ggml_type type);
  623. GGML_API const char * ggml_op_name (enum ggml_op op);
  624. GGML_API const char * ggml_op_symbol(enum ggml_op op);
  625. GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
  626. GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
  627. GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
  628. GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); // NOTE(review): per-element size in bytes; meaning for block-quantized types not visible here - confirm
  629. GGML_API bool ggml_is_quantized(enum ggml_type type);
  630. // TODO: temporary until model loading of ggml examples is refactored
  631. GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
  632. GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); // shape/layout predicates on a single tensor
  633. GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
  634. GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor);
  635. GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
  636. GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
  637. GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
  638. GGML_API bool ggml_is_3d (const struct ggml_tensor * tensor);
  639. GGML_API int ggml_n_dims (const struct ggml_tensor * tensor); // returns 1 for scalars
  640. // returns whether the tensor elements can be iterated over with a flattened index (no gaps, no permutation)
  641. GGML_API bool ggml_is_contiguous (const struct ggml_tensor * tensor);
  642. GGML_API bool ggml_is_contiguous_0(const struct ggml_tensor * tensor); // same as ggml_is_contiguous()
  643. GGML_API bool ggml_is_contiguous_1(const struct ggml_tensor * tensor); // contiguous for dims >= 1
  644. GGML_API bool ggml_is_contiguous_2(const struct ggml_tensor * tensor); // contiguous for dims >= 2
  645. // returns whether the tensor elements are allocated as one contiguous block of memory (no gaps, but permutation ok)
  646. GGML_API bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor);
  647. // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
  648. GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
  649. // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
  650. GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
  651. GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1); // pairwise comparisons of two tensors
  652. GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
  653. GGML_API bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1); // presumably: whether t0 can be repeated to t1's shape (cf. ggml_repeat) - confirm
  654. // use this to compute the memory overhead of a tensor
  655. GGML_API size_t ggml_tensor_overhead(void);
  656. GGML_API bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes); // checks nbytes of raw data of the given type; exact checks not visible here - confirm
  657. // main: context lifecycle and memory-pool queries
  658. GGML_API struct ggml_context * ggml_init (struct ggml_init_params params);
  659. GGML_API void ggml_reset(struct ggml_context * ctx);
  660. GGML_API void ggml_free (struct ggml_context * ctx);
  661. GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
  662. GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx); // NOTE(review): getter takes a non-const ctx, unlike the other getters here; changing it requires updating the definition too
  663. GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
  664. GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
  665. GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx);
  666. GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx);
  667. GGML_API struct ggml_tensor * ggml_new_tensor( // create a tensor of `type` in ctx
  668. struct ggml_context * ctx,
  669. enum ggml_type type,
  670. int n_dims,
  671. const int64_t *ne); // element counts per dimension; presumably n_dims entries - confirm
  672. GGML_API struct ggml_tensor * ggml_new_tensor_1d( // fixed-rank shorthands: neK is the element count of dimension K
  673. struct ggml_context * ctx,
  674. enum ggml_type type,
  675. int64_t ne0);
  676. GGML_API struct ggml_tensor * ggml_new_tensor_2d(
  677. struct ggml_context * ctx,
  678. enum ggml_type type,
  679. int64_t ne0,
  680. int64_t ne1);
  681. GGML_API struct ggml_tensor * ggml_new_tensor_3d(
  682. struct ggml_context * ctx,
  683. enum ggml_type type,
  684. int64_t ne0,
  685. int64_t ne1,
  686. int64_t ne2);
  687. GGML_API struct ggml_tensor * ggml_new_tensor_4d(
  688. struct ggml_context * ctx,
  689. enum ggml_type type,
  690. int64_t ne0,
  691. int64_t ne1,
  692. int64_t ne2,
  693. int64_t ne3);
  694. GGML_API void * ggml_new_buffer(struct ggml_context * ctx, size_t nbytes); // raw buffer of nbytes from ctx (presumably from the context memory pool - confirm)
  695. GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); // new tensor in ctx modelled on src (whether data is copied is not visible here - confirm)
  696. GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); // view of src
  697. // Context tensor enumeration and lookup (iterate with get_first_tensor/get_next_tensor)
  698. GGML_API struct ggml_tensor * ggml_get_first_tensor(const struct ggml_context * ctx);
  699. GGML_API struct ggml_tensor * ggml_get_next_tensor (const struct ggml_context * ctx, struct ggml_tensor * tensor);
  700. GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); // lookup by name (presumably NULL if not found - confirm); NOTE(review): non-const ctx unlike the enumeration functions
  701. // Converts a flat index into coordinates
  702. GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); // i: flat index; i0..i3: output coordinates
  703. GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
  704. GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);
  705. GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
  706. GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); // presumably only valid for F32 tensors - confirm
  707. GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor);
  708. GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name);
  709. GGML_ATTRIBUTE_FORMAT(2, 3) // printf-style format checking: argument 2 is the format, variadic arguments start at 3
  710. GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
  711. // Tensor flags: mark the tensor's role (presumably sets the matching GGML_TENSOR_FLAG_* bit - confirm)
  712. GGML_API void ggml_set_input(struct ggml_tensor * tensor);
  713. GGML_API void ggml_set_output(struct ggml_tensor * tensor);
  714. GGML_API void ggml_set_param(struct ggml_tensor * tensor);
  715. GGML_API void ggml_set_loss(struct ggml_tensor * tensor);
  716. //
  717. // operations on tensors with backpropagation (this group: dup, add, add_cast, add_id, add1, acc, sub, mul, div; result tensors are created in ctx)
  718. //
  719. GGML_API struct ggml_tensor * ggml_dup(
  720. struct ggml_context * ctx,
  721. struct ggml_tensor * a);
  722. // in-place, returns view(a) (the other *_inplace variants below presumably follow the same convention - confirm per op)
  723. GGML_API struct ggml_tensor * ggml_dup_inplace(
  724. struct ggml_context * ctx,
  725. struct ggml_tensor * a);
  726. GGML_API struct ggml_tensor * ggml_add(
  727. struct ggml_context * ctx,
  728. struct ggml_tensor * a,
  729. struct ggml_tensor * b);
  730. GGML_API struct ggml_tensor * ggml_add_inplace(
  731. struct ggml_context * ctx,
  732. struct ggml_tensor * a,
  733. struct ggml_tensor * b);
  734. GGML_API struct ggml_tensor * ggml_add_cast(
  735. struct ggml_context * ctx,
  736. struct ggml_tensor * a,
  737. struct ggml_tensor * b,
  738. enum ggml_type type); // presumably the type of the result tensor - confirm
  739. // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
  740. GGML_API struct ggml_tensor * ggml_add_id(
  741. struct ggml_context * ctx,
  742. struct ggml_tensor * a,
  743. struct ggml_tensor * b,
  744. struct ggml_tensor * ids);
  745. GGML_API struct ggml_tensor * ggml_add1( // presumably adds a single-element b to every element of a - confirm
  746. struct ggml_context * ctx,
  747. struct ggml_tensor * a,
  748. struct ggml_tensor * b);
  749. GGML_API struct ggml_tensor * ggml_add1_inplace(
  750. struct ggml_context * ctx,
  751. struct ggml_tensor * a,
  752. struct ggml_tensor * b);
  753. // dst = a
  754. // view(dst, nb1, nb2, nb3, offset) += b
  755. // return dst
  756. GGML_API struct ggml_tensor * ggml_acc(
  757. struct ggml_context * ctx,
  758. struct ggml_tensor * a,
  759. struct ggml_tensor * b,
  760. size_t nb1,
  761. size_t nb2,
  762. size_t nb3,
  763. size_t offset); // nb1..nb3 and offset are in bytes
  764. GGML_API struct ggml_tensor * ggml_acc_inplace(
  765. struct ggml_context * ctx,
  766. struct ggml_tensor * a,
  767. struct ggml_tensor * b,
  768. size_t nb1,
  769. size_t nb2,
  770. size_t nb3,
  771. size_t offset);
  772. GGML_API struct ggml_tensor * ggml_sub(
  773. struct ggml_context * ctx,
  774. struct ggml_tensor * a,
  775. struct ggml_tensor * b);
  776. GGML_API struct ggml_tensor * ggml_sub_inplace(
  777. struct ggml_context * ctx,
  778. struct ggml_tensor * a,
  779. struct ggml_tensor * b);
  780. GGML_API struct ggml_tensor * ggml_mul(
  781. struct ggml_context * ctx,
  782. struct ggml_tensor * a,
  783. struct ggml_tensor * b);
  784. GGML_API struct ggml_tensor * ggml_mul_inplace(
  785. struct ggml_context * ctx,
  786. struct ggml_tensor * a,
  787. struct ggml_tensor * b);
  788. GGML_API struct ggml_tensor * ggml_div(
  789. struct ggml_context * ctx,
  790. struct ggml_tensor * a,
  791. struct ggml_tensor * b);
  792. GGML_API struct ggml_tensor * ggml_div_inplace(
  793. struct ggml_context * ctx,
  794. struct ggml_tensor * a,
  795. struct ggml_tensor * b);
  796. GGML_API struct ggml_tensor * ggml_sqr( // unary math ops: sqr, sqrt, log, sin, cos - each with an *_inplace variant
  797. struct ggml_context * ctx,
  798. struct ggml_tensor * a);
  799. GGML_API struct ggml_tensor * ggml_sqr_inplace(
  800. struct ggml_context * ctx,
  801. struct ggml_tensor * a);
  802. GGML_API struct ggml_tensor * ggml_sqrt(
  803. struct ggml_context * ctx,
  804. struct ggml_tensor * a);
  805. GGML_API struct ggml_tensor * ggml_sqrt_inplace(
  806. struct ggml_context * ctx,
  807. struct ggml_tensor * a);
  808. GGML_API struct ggml_tensor * ggml_log(
  809. struct ggml_context * ctx,
  810. struct ggml_tensor * a);
  811. GGML_API struct ggml_tensor * ggml_log_inplace(
  812. struct ggml_context * ctx,
  813. struct ggml_tensor * a);
  814. GGML_API struct ggml_tensor * ggml_sin(
  815. struct ggml_context * ctx,
  816. struct ggml_tensor * a);
  817. GGML_API struct ggml_tensor * ggml_sin_inplace(
  818. struct ggml_context * ctx,
  819. struct ggml_tensor * a);
  820. GGML_API struct ggml_tensor * ggml_cos(
  821. struct ggml_context * ctx,
  822. struct ggml_tensor * a);
  823. GGML_API struct ggml_tensor * ggml_cos_inplace(
  824. struct ggml_context * ctx,
  825. struct ggml_tensor * a);
  826. // return scalar: sum of all elements of a
  827. GGML_API struct ggml_tensor * ggml_sum(
  828. struct ggml_context * ctx,
  829. struct ggml_tensor * a);
  830. // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d]
  831. GGML_API struct ggml_tensor * ggml_sum_rows(
  832. struct ggml_context * ctx,
  833. struct ggml_tensor * a);
  834. // mean along rows (presumably same output shape as ggml_sum_rows - confirm)
  835. GGML_API struct ggml_tensor * ggml_mean(
  836. struct ggml_context * ctx,
  837. struct ggml_tensor * a);
  838. // argmax along rows: index of the largest element in each row (index type not visible here - confirm)
  839. GGML_API struct ggml_tensor * ggml_argmax(
  840. struct ggml_context * ctx,
  841. struct ggml_tensor * a);
  842. // count number of equal elements in a and b (presumably a and b must have the same shape - confirm)
  843. GGML_API struct ggml_tensor * ggml_count_equal(
  844. struct ggml_context * ctx,
  845. struct ggml_tensor * a,
  846. struct ggml_tensor * b);
  847. // if a is the same shape as b, and a is not parameter, return a
  848. // otherwise, return a new tensor: repeat(a) to fit in b (b presumably supplies only the target shape - confirm)
  849. GGML_API struct ggml_tensor * ggml_repeat(
  850. struct ggml_context * ctx,
  851. struct ggml_tensor * a,
  852. struct ggml_tensor * b);
  853. // repeat a to the specified shape [ne0, ne1, ne2, ne3]
  854. GGML_API struct ggml_tensor * ggml_repeat_4d(
  855. struct ggml_context * ctx,
  856. struct ggml_tensor * a,
  857. int64_t ne0,
  858. int64_t ne1,
  859. int64_t ne2,
  860. int64_t ne3);
  861. // sums repetitions in a into shape of b
  862. GGML_API struct ggml_tensor * ggml_repeat_back(
  863. struct ggml_context * ctx,
  864. struct ggml_tensor * a,
  865. struct ggml_tensor * b); // sum up values that are adjacent in dims > 0 instead of repeated with same stride
  866. // concat a and b along dimension index dim
  867. // used in stable-diffusion
  868. GGML_API struct ggml_tensor * ggml_concat(
  869. struct ggml_context * ctx,
  870. struct ggml_tensor * a,
  871. struct ggml_tensor * b,
  872. int dim);
  873. GGML_API struct ggml_tensor * ggml_abs( // activation / unary ops (abs, sgn, neg, step, tanh, elu, relu, sigmoid, gelu*, silu*, hardswish, hardsigmoid, exp)
  874. struct ggml_context * ctx,
  875. struct ggml_tensor * a);
  876. GGML_API struct ggml_tensor * ggml_abs_inplace(
  877. struct ggml_context * ctx,
  878. struct ggml_tensor * a);
  879. GGML_API struct ggml_tensor * ggml_sgn(
  880. struct ggml_context * ctx,
  881. struct ggml_tensor * a);
  882. GGML_API struct ggml_tensor * ggml_sgn_inplace(
  883. struct ggml_context * ctx,
  884. struct ggml_tensor * a);
  885. GGML_API struct ggml_tensor * ggml_neg(
  886. struct ggml_context * ctx,
  887. struct ggml_tensor * a);
  888. GGML_API struct ggml_tensor * ggml_neg_inplace(
  889. struct ggml_context * ctx,
  890. struct ggml_tensor * a);
  891. GGML_API struct ggml_tensor * ggml_step(
  892. struct ggml_context * ctx,
  893. struct ggml_tensor * a);
  894. GGML_API struct ggml_tensor * ggml_step_inplace(
  895. struct ggml_context * ctx,
  896. struct ggml_tensor * a);
  897. GGML_API struct ggml_tensor * ggml_tanh(
  898. struct ggml_context * ctx,
  899. struct ggml_tensor * a);
  900. GGML_API struct ggml_tensor * ggml_tanh_inplace(
  901. struct ggml_context * ctx,
  902. struct ggml_tensor * a);
  903. GGML_API struct ggml_tensor * ggml_elu(
  904. struct ggml_context * ctx,
  905. struct ggml_tensor * a);
  906. GGML_API struct ggml_tensor * ggml_elu_inplace(
  907. struct ggml_context * ctx,
  908. struct ggml_tensor * a);
  909. GGML_API struct ggml_tensor * ggml_relu(
  910. struct ggml_context * ctx,
  911. struct ggml_tensor * a);
  912. GGML_API struct ggml_tensor * ggml_leaky_relu(
  913. struct ggml_context * ctx,
  914. struct ggml_tensor * a, float negative_slope, bool inplace); // NOTE(review): unlike its neighbours, in-place mode is selected by the `inplace` flag instead of a separate *_inplace function
  915. GGML_API struct ggml_tensor * ggml_relu_inplace(
  916. struct ggml_context * ctx,
  917. struct ggml_tensor * a);
  918. GGML_API struct ggml_tensor * ggml_sigmoid(
  919. struct ggml_context * ctx,
  920. struct ggml_tensor * a);
  921. GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
  922. struct ggml_context * ctx,
  923. struct ggml_tensor * a);
  924. GGML_API struct ggml_tensor * ggml_gelu(
  925. struct ggml_context * ctx,
  926. struct ggml_tensor * a);
  927. GGML_API struct ggml_tensor * ggml_gelu_inplace(
  928. struct ggml_context * ctx,
  929. struct ggml_tensor * a);
  930. // GELU using erf (error function) when possible
  931. // some backends may fallback to approximation based on Abramowitz and Stegun formula
  932. GGML_API struct ggml_tensor * ggml_gelu_erf(
  933. struct ggml_context * ctx,
  934. struct ggml_tensor * a);
  935. GGML_API struct ggml_tensor * ggml_gelu_erf_inplace(
  936. struct ggml_context * ctx,
  937. struct ggml_tensor * a);
  938. GGML_API struct ggml_tensor * ggml_gelu_quick(
  939. struct ggml_context * ctx,
  940. struct ggml_tensor * a);
  941. GGML_API struct ggml_tensor * ggml_gelu_quick_inplace(
  942. struct ggml_context * ctx,
  943. struct ggml_tensor * a);
  944. GGML_API struct ggml_tensor * ggml_silu(
  945. struct ggml_context * ctx,
  946. struct ggml_tensor * a);
  947. GGML_API struct ggml_tensor * ggml_silu_inplace(
  948. struct ggml_context * ctx,
  949. struct ggml_tensor * a);
  950. // a - x (the input of the forward silu)
  951. // b - dy (presumably the gradient w.r.t. the silu output)
  952. GGML_API struct ggml_tensor * ggml_silu_back(
  953. struct ggml_context * ctx,
  954. struct ggml_tensor * a,
  955. struct ggml_tensor * b);
  956. // hardswish(x) = x * relu6(x + 3) / 6
  957. GGML_API struct ggml_tensor * ggml_hardswish(
  958. struct ggml_context * ctx,
  959. struct ggml_tensor * a);
  960. // hardsigmoid(x) = relu6(x + 3) / 6
  961. GGML_API struct ggml_tensor * ggml_hardsigmoid(
  962. struct ggml_context * ctx,
  963. struct ggml_tensor * a);
  964. GGML_API struct ggml_tensor * ggml_exp(
  965. struct ggml_context * ctx,
  966. struct ggml_tensor * a);
  967. GGML_API struct ggml_tensor * ggml_exp_inplace(
  968. struct ggml_context * ctx,
  969. struct ggml_tensor * a);
  970. // gated linear unit ops
  971. // a: n columns, r rows
  972. // result: n / 2 columns, r rows
  973. // expects gate in second half of row, unless swapped is true (then presumably in the first half - confirm)
  974. GGML_API struct ggml_tensor * ggml_glu(
  975. struct ggml_context * ctx,
  976. struct ggml_tensor * a,
  977. enum ggml_glu_op op,
  978. bool swapped);
  979. GGML_API struct ggml_tensor * ggml_reglu( // fixed-op shorthands for ggml_glu; *_swapped variants presumably pass swapped = true - confirm
  980. struct ggml_context * ctx,
  981. struct ggml_tensor * a);
  982. GGML_API struct ggml_tensor * ggml_reglu_swapped(
  983. struct ggml_context * ctx,
  984. struct ggml_tensor * a);
  985. GGML_API struct ggml_tensor * ggml_geglu(
  986. struct ggml_context * ctx,
  987. struct ggml_tensor * a);
  988. GGML_API struct ggml_tensor * ggml_geglu_swapped(
  989. struct ggml_context * ctx,
  990. struct ggml_tensor * a);
  991. GGML_API struct ggml_tensor * ggml_swiglu(
  992. struct ggml_context * ctx,
  993. struct ggml_tensor * a);
  994. GGML_API struct ggml_tensor * ggml_swiglu_swapped(
  995. struct ggml_context * ctx,
  996. struct ggml_tensor * a);
  997. GGML_API struct ggml_tensor * ggml_geglu_erf(
  998. struct ggml_context * ctx,
  999. struct ggml_tensor * a);
  1000. GGML_API struct ggml_tensor * ggml_geglu_erf_swapped(
  1001. struct ggml_context * ctx,
  1002. struct ggml_tensor * a);
  1003. GGML_API struct ggml_tensor * ggml_geglu_quick(
  1004. struct ggml_context * ctx,
  1005. struct ggml_tensor * a);
  1006. GGML_API struct ggml_tensor * ggml_geglu_quick_swapped(
  1007. struct ggml_context * ctx,
  1008. struct ggml_tensor * a);
  1009. // split variants: input and gate come from two tensors, a and b, each n columns, r rows
  1010. // (which of a/b is the gate and the result shape are not stated here - presumably n columns, r rows; confirm)
  1011. GGML_API struct ggml_tensor * ggml_glu_split(
  1012. struct ggml_context * ctx,
  1013. struct ggml_tensor * a,
  1014. struct ggml_tensor * b,
  1015. enum ggml_glu_op op);
  1016. GGML_API struct ggml_tensor * ggml_reglu_split(
  1017. struct ggml_context * ctx,
  1018. struct ggml_tensor * a,
  1019. struct ggml_tensor * b);
  1020. GGML_API struct ggml_tensor * ggml_geglu_split(
  1021. struct ggml_context * ctx,
  1022. struct ggml_tensor * a,
  1023. struct ggml_tensor * b);
  1024. GGML_API struct ggml_tensor * ggml_swiglu_split(
  1025. struct ggml_context * ctx,
  1026. struct ggml_tensor * a,
  1027. struct ggml_tensor * b);
  1028. GGML_API struct ggml_tensor * ggml_geglu_erf_split(
  1029. struct ggml_context * ctx,
  1030. struct ggml_tensor * a,
  1031. struct ggml_tensor * b);
  1032. GGML_API struct ggml_tensor * ggml_geglu_quick_split(
  1033. struct ggml_context * ctx,
  1034. struct ggml_tensor * a,
  1035. struct ggml_tensor * b);
  1036. GGML_API struct ggml_tensor * ggml_swiglu_oai( // NOTE(review): meaning of alpha and limit is not documented here - confirm in the implementation
  1037. struct ggml_context * ctx,
  1038. struct ggml_tensor * a,
  1039. struct ggml_tensor * b,
  1040. float alpha,
  1041. float limit);
  1042. // normalize along rows (eps: small constant, presumably guarding against division by zero - confirm)
  1043. GGML_API struct ggml_tensor * ggml_norm(
  1044. struct ggml_context * ctx,
  1045. struct ggml_tensor * a,
  1046. float eps);
  1047. GGML_API struct ggml_tensor * ggml_norm_inplace(
  1048. struct ggml_context * ctx,
  1049. struct ggml_tensor * a,
  1050. float eps);
  1051. GGML_API struct ggml_tensor * ggml_rms_norm( // RMS-normalize along rows
  1052. struct ggml_context * ctx,
  1053. struct ggml_tensor * a,
  1054. float eps);
  1055. GGML_API struct ggml_tensor * ggml_rms_norm_inplace(
  1056. struct ggml_context * ctx,
  1057. struct ggml_tensor * a,
  1058. float eps);
  1059. // group normalize along ne0*ne1*n_groups
  1060. // used in stable-diffusion
  1061. GGML_API struct ggml_tensor * ggml_group_norm(
  1062. struct ggml_context * ctx,
  1063. struct ggml_tensor * a,
  1064. int n_groups,
  1065. float eps);
  1066. GGML_API struct ggml_tensor * ggml_group_norm_inplace(
  1067. struct ggml_context * ctx,
  1068. struct ggml_tensor * a,
  1069. int n_groups,
  1070. float eps);
  1071. // l2 normalize along rows
  1072. // used in rwkv v7
  1073. GGML_API struct ggml_tensor * ggml_l2_norm(
  1074. struct ggml_context * ctx,
  1075. struct ggml_tensor * a,
  1076. float eps);
  1077. GGML_API struct ggml_tensor * ggml_l2_norm_inplace(
  1078. struct ggml_context * ctx,
  1079. struct ggml_tensor * a,
  1080. float eps);
  1081. // a - x (the input of the forward rms_norm)
  1082. // b - dy (presumably the gradient w.r.t. the rms_norm output)
  1083. GGML_API struct ggml_tensor * ggml_rms_norm_back(
  1084. struct ggml_context * ctx,
  1085. struct ggml_tensor * a,
  1086. struct ggml_tensor * b,
  1087. float eps);
  1088. // A: k columns, n rows => [ne03, ne02, n, k]
  1089. // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k]
  1090. // result is n columns, m rows => [ne03 * x, ne02 * y, m, n]
  1091. GGML_API struct ggml_tensor * ggml_mul_mat(
  1092. struct ggml_context * ctx,
  1093. struct ggml_tensor * a,
  1094. struct ggml_tensor * b);
  1095. // change the precision of a matrix multiplication
  1096. // set to GGML_PREC_F32 for higher precision (useful for phi-2)
  1097. GGML_API void ggml_mul_mat_set_prec(
  1098. struct ggml_tensor * a, // presumably the result tensor of ggml_mul_mat - confirm
  1099. enum ggml_prec prec);
  1100. // indirect matrix multiplication (as: stack of matrices; ids: selects which matrix of `as` is used - presumably per row of b; confirm)
  1101. GGML_API struct ggml_tensor * ggml_mul_mat_id(
  1102. struct ggml_context * ctx,
  1103. struct ggml_tensor * as,
  1104. struct ggml_tensor * b,
  1105. struct ggml_tensor * ids);
  1106. // a: m columns, n rows
  1107. // b: p columns, n rows
  1108. // result: m columns, p rows
  1109. GGML_API struct ggml_tensor * ggml_out_prod(
  1110. struct ggml_context * ctx,
  1111. struct ggml_tensor * a,
  1112. struct ggml_tensor * b);
  1113. //
  1114. // operations on tensors without backpropagation
  1115. //
  1116. GGML_API struct ggml_tensor * ggml_scale( // result = s * a (s is a scalar)
  1117. struct ggml_context * ctx,
  1118. struct ggml_tensor * a,
  1119. float s);
  1120. // in-place, returns view(a)
  1121. GGML_API struct ggml_tensor * ggml_scale_inplace(
  1122. struct ggml_context * ctx,
  1123. struct ggml_tensor * a,
  1124. float s);
  1125. // x = s * a + b, where s and b are scalars
  1126. GGML_API struct ggml_tensor * ggml_scale_bias(
  1127. struct ggml_context * ctx,
  1128. struct ggml_tensor * a,
  1129. float s,
  1130. float b);
  1131. GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
  1132. struct ggml_context * ctx,
  1133. struct ggml_tensor * a,
  1134. float s,
  1135. float b);
  1136. // b -> view(a,offset,nb1,nb2,nb3), return modified a (strides and offset in bytes)
  1137. GGML_API struct ggml_tensor * ggml_set(
  1138. struct ggml_context * ctx,
  1139. struct ggml_tensor * a,
  1140. struct ggml_tensor * b,
  1141. size_t nb1,
  1142. size_t nb2,
  1143. size_t nb3,
  1144. size_t offset); // in bytes
  1145. // b -> view(a,offset,nb1,nb2,nb3), return view(a)
  1146. GGML_API struct ggml_tensor * ggml_set_inplace(
  1147. struct ggml_context * ctx,
  1148. struct ggml_tensor * a,
  1149. struct ggml_tensor * b,
  1150. size_t nb1,
  1151. size_t nb2,
  1152. size_t nb3,
  1153. size_t offset); // in bytes
  1154. GGML_API struct ggml_tensor * ggml_set_1d( // presumably b -> view(a,offset), return modified a - confirm
  1155. struct ggml_context * ctx,
  1156. struct ggml_tensor * a,
  1157. struct ggml_tensor * b,
  1158. size_t offset); // in bytes
  1159. GGML_API struct ggml_tensor * ggml_set_1d_inplace( // presumably b -> view(a,offset), return view(a) - confirm
  1160. struct ggml_context * ctx,
  1161. struct ggml_tensor * a,
  1162. struct ggml_tensor * b,
  1163. size_t offset); // in bytes
  1164. // b -> view(a,offset,nb1), return modified a (nb2/nb3 are not parameters here - presumably taken from a; confirm)
  1165. GGML_API struct ggml_tensor * ggml_set_2d(
  1166. struct ggml_context * ctx,
  1167. struct ggml_tensor * a,
  1168. struct ggml_tensor * b,
  1169. size_t nb1,
  1170. size_t offset); // in bytes
  1171. // b -> view(a,offset,nb1), return view(a) (nb2/nb3 as for ggml_set_2d)
  1172. GGML_API struct ggml_tensor * ggml_set_2d_inplace(
  1173. struct ggml_context * ctx,
  1174. struct ggml_tensor * a,
  1175. struct ggml_tensor * b,
  1176. size_t nb1,
  1177. size_t offset); // in bytes
  1178. // a -> b, return view(b): copies the data of a into b
  1179. GGML_API struct ggml_tensor * ggml_cpy(
  1180. struct ggml_context * ctx,
  1181. struct ggml_tensor * a,
  1182. struct ggml_tensor * b);
  1183. GGML_API struct ggml_tensor * ggml_cast( // presumably a copy of a converted to `type` - confirm
  1184. struct ggml_context * ctx,
  1185. struct ggml_tensor * a,
  1186. enum ggml_type type);
  1187. // make contiguous
  1188. GGML_API struct ggml_tensor * ggml_cont(
  1189. struct ggml_context * ctx,
  1190. struct ggml_tensor * a);
  1191. // make contiguous, with new shape (neK: element count of dimension K)
  1192. GGML_API struct ggml_tensor * ggml_cont_1d(
  1193. struct ggml_context * ctx,
  1194. struct ggml_tensor * a,
  1195. int64_t ne0);
  1196. GGML_API struct ggml_tensor * ggml_cont_2d(
  1197. struct ggml_context * ctx,
  1198. struct ggml_tensor * a,
  1199. int64_t ne0,
  1200. int64_t ne1);
  1201. GGML_API struct ggml_tensor * ggml_cont_3d(
  1202. struct ggml_context * ctx,
  1203. struct ggml_tensor * a,
  1204. int64_t ne0,
  1205. int64_t ne1,
  1206. int64_t ne2);
  1207. GGML_API struct ggml_tensor * ggml_cont_4d(
  1208. struct ggml_context * ctx,
  1209. struct ggml_tensor * a,
  1210. int64_t ne0,
  1211. int64_t ne1,
  1212. int64_t ne2,
  1213. int64_t ne3);
  1214. // return view(a), b specifies the new shape
  1215. // TODO: when we start computing gradient, make a copy instead of view
  1216. GGML_API struct ggml_tensor * ggml_reshape(
  1217. struct ggml_context * ctx,
  1218. struct ggml_tensor * a,
  1219. struct ggml_tensor * b);
  1220. // return view(a) with shape given by ne0..neN
  1221. // TODO: when we start computing gradient, make a copy instead of view
  1222. GGML_API struct ggml_tensor * ggml_reshape_1d(
  1223. struct ggml_context * ctx,
  1224. struct ggml_tensor * a,
  1225. int64_t ne0);
  1226. GGML_API struct ggml_tensor * ggml_reshape_2d( // return view(a)
  1227. struct ggml_context * ctx,
  1228. struct ggml_tensor * a,
  1229. int64_t ne0,
  1230. int64_t ne1);
  1231. // return view(a)
  1232. // TODO: when we start computing gradient, make a copy instead of view
  1233. GGML_API struct ggml_tensor * ggml_reshape_3d(
  1234. struct ggml_context * ctx,
  1235. struct ggml_tensor * a,
  1236. int64_t ne0,
  1237. int64_t ne1,
  1238. int64_t ne2);
  1239. GGML_API struct ggml_tensor * ggml_reshape_4d( // return view(a)
  1240. struct ggml_context * ctx,
  1241. struct ggml_tensor * a,
  1242. int64_t ne0,
  1243. int64_t ne1,
  1244. int64_t ne2,
  1245. int64_t ne3);
  1246. // views of a with the given shape (neK) and byte strides (nbK); offset in bytes
  1247. GGML_API struct ggml_tensor * ggml_view_1d(
  1248. struct ggml_context * ctx,
  1249. struct ggml_tensor * a,
  1250. int64_t ne0,
  1251. size_t offset);
  1252. GGML_API struct ggml_tensor * ggml_view_2d(
  1253. struct ggml_context * ctx,
  1254. struct ggml_tensor * a,
  1255. int64_t ne0,
  1256. int64_t ne1,
  1257. size_t nb1, // row stride in bytes
  1258. size_t offset);
  1259. GGML_API struct ggml_tensor * ggml_view_3d(
  1260. struct ggml_context * ctx,
  1261. struct ggml_tensor * a,
  1262. int64_t ne0,
  1263. int64_t ne1,
  1264. int64_t ne2,
  1265. size_t nb1, // row stride in bytes
  1266. size_t nb2, // slice stride in bytes
  1267. size_t offset);
  1268. GGML_API struct ggml_tensor * ggml_view_4d(
  1269. struct ggml_context * ctx,
  1270. struct ggml_tensor * a,
  1271. int64_t ne0,
  1272. int64_t ne1,
  1273. int64_t ne2,
  1274. int64_t ne3,
  1275. size_t nb1, // row stride in bytes
  1276. size_t nb2, // slice stride in bytes
  1277. size_t nb3, // stride of dimension 3 in bytes
  1278. size_t offset);
  1279. GGML_API struct ggml_tensor * ggml_permute( // axisK: presumably the destination position of source dimension K - confirm
  1280. struct ggml_context * ctx,
  1281. struct ggml_tensor * a,
  1282. int axis0,
  1283. int axis1,
  1284. int axis2,
  1285. int axis3);
  1286. // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
  1287. GGML_API struct ggml_tensor * ggml_transpose(
  1288. struct ggml_context * ctx,
  1289. struct ggml_tensor * a);
  1290. // gather the rows of a selected by the indices in b; supports 3D: a->ne[2] == b->ne[1]
  1291. GGML_API struct ggml_tensor * ggml_get_rows(
  1292. struct ggml_context * ctx,
  1293. struct ggml_tensor * a, // data
  1294. struct ggml_tensor * b); // row indices
  1295. GGML_API struct ggml_tensor * ggml_get_rows_back(
  1296. struct ggml_context * ctx,
  1297. struct ggml_tensor * a, // gradients of ggml_get_rows result
  1298. struct ggml_tensor * b, // row indices
  1299. struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape
  1300. // a TD [n_embd, ne1, ne2, ne3] (TD: destination type)
  1301. // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3 (TS: source type)
  1302. // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1)
  1303. //
  1304. // undefined behavior if destination rows overlap
  1305. //
  1306. // broadcast:
  1307. // ne2 % ne11 == 0
  1308. // ne3 % ne12 == 0
  1309. //
  1310. // return view(a)
  1311. GGML_API struct ggml_tensor * ggml_set_rows(
  1312. struct ggml_context * ctx,
  1313. struct ggml_tensor * a, // destination
  1314. struct ggml_tensor * b, // source
  1315. struct ggml_tensor * c); // row indices
  1316. GGML_API struct ggml_tensor * ggml_diag( // NOTE(review): semantics not documented here - confirm in the implementation
  1317. struct ggml_context * ctx,
  1318. struct ggml_tensor * a);
  1319. // set elements above the diagonal to -INF (n_past presumably shifts the diagonal - confirm)
  1320. GGML_API struct ggml_tensor * ggml_diag_mask_inf(
  1321. struct ggml_context * ctx,
  1322. struct ggml_tensor * a,
  1323. int n_past);
  1324. // in-place, returns view(a)
  1325. GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(
  1326. struct ggml_context * ctx,
  1327. struct ggml_tensor * a,
  1328. int n_past);
  1329. // set elements above the diagonal to 0 (n_past as for ggml_diag_mask_inf)
  1330. GGML_API struct ggml_tensor * ggml_diag_mask_zero(
  1331. struct ggml_context * ctx,
  1332. struct ggml_tensor * a,
  1333. int n_past);
  1334. // in-place, returns view(a)
  1335. GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(
  1336. struct ggml_context * ctx,
  1337. struct ggml_tensor * a,
  1338. int n_past);
  1339. GGML_API struct ggml_tensor * ggml_soft_max( // softmax of a (presumably along rows, like ggml_soft_max_ext - confirm)
  1340. struct ggml_context * ctx,
  1341. struct ggml_tensor * a);
  1342. // in-place, returns view(a)
  1343. GGML_API struct ggml_tensor * ggml_soft_max_inplace(
  1344. struct ggml_context * ctx,
  1345. struct ggml_tensor * a);
  1346. // a [ne0, ne01, ne02, ne03]
  1347. // mask [ne0, ne11, ne12, ne13] | ne11 >= ne01, F16 or F32, optional
  1348. //
  1349. // broadcast:
  1350. // ne02 % ne12 == 0
  1351. // ne03 % ne13 == 0
  1352. //
  1353. // fused soft_max(a*scale + mask*(ALiBi slope))
  1354. // max_bias = 0.0f for no ALiBi
  1355. GGML_API struct ggml_tensor * ggml_soft_max_ext(
  1356. struct ggml_context * ctx,
  1357. struct ggml_tensor * a,
  1358. struct ggml_tensor * mask,
  1359. float scale,
  1360. float max_bias);
  1361. GGML_API void ggml_soft_max_add_sinks( // NOTE(review): attaches `sinks` to an existing soft_max result `a` (returns nothing) - presumably a soft_max_ext result; confirm
  1362. struct ggml_tensor * a,
  1363. struct ggml_tensor * sinks);
  1364. GGML_API struct ggml_tensor * ggml_soft_max_ext_back( // backward pass of ggml_soft_max_ext - presumably; confirm which of a/b is the incoming gradient
  1365. struct ggml_context * ctx,
  1366. struct ggml_tensor * a,
  1367. struct ggml_tensor * b,
  1368. float scale,
  1369. float max_bias);
  1370. // in-place, returns view(a)
  1371. GGML_API struct ggml_tensor * ggml_soft_max_ext_back_inplace(
  1372. struct ggml_context * ctx,
  1373. struct ggml_tensor * a,
  1374. struct ggml_tensor * b,
  1375. float scale,
  1376. float max_bias);
  // rotary position embedding
  // if (mode & 1) - skip n_past elements (NOT SUPPORTED)
  // if (mode & GGML_ROPE_TYPE_NEOX) - GPT-NeoX style
  //
  // b is an int32 vector with size a->ne[2], it contains the positions
  GGML_API struct ggml_tensor * ggml_rope(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          int n_dims,
          int mode);

  // in-place variant of ggml_rope, returns view(a)
  GGML_API struct ggml_tensor * ggml_rope_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          int n_dims,
          int mode);

  // custom RoPE
  // b holds the positions (same role as in ggml_rope)
  // c is freq factors (e.g. phi3-128k), (optional)
  // beta_fast / beta_slow are the YaRN parameters also accepted by ggml_rope_yarn_corr_dims
  // NOTE(review): exact semantics of freq_scale, ext_factor and attn_factor are not documented
  // in this header — see the implementation
  GGML_API struct ggml_tensor * ggml_rope_ext(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          struct ggml_tensor * c,
          int n_dims,
          int mode,
          int n_ctx_orig,
          float freq_base,
          float freq_scale,
          float ext_factor,
          float attn_factor,
          float beta_fast,
          float beta_slow);

  // multi-section RoPE: the parameters of ggml_rope_ext plus sections[GGML_MROPE_SECTIONS]
  // NOTE(review): the meaning of each sections[] entry is not documented here — confirm
  GGML_API struct ggml_tensor * ggml_rope_multi(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          struct ggml_tensor * c,
          int n_dims,
          int sections[GGML_MROPE_SECTIONS],
          int mode,
          int n_ctx_orig,
          float freq_base,
          float freq_scale,
          float ext_factor,
          float attn_factor,
          float beta_fast,
          float beta_slow);

  // in-place variant of ggml_rope_ext, returns view(a)
  GGML_API struct ggml_tensor * ggml_rope_ext_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          struct ggml_tensor * c,
          int n_dims,
          int mode,
          int n_ctx_orig,
          float freq_base,
          float freq_scale,
          float ext_factor,
          float attn_factor,
          float beta_fast,
          float beta_slow);

  // in-place variant of ggml_rope_multi
  // NOTE(review): presumably returns view(a) like ggml_rope_ext_inplace — confirm
  GGML_API struct ggml_tensor * ggml_rope_multi_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          struct ggml_tensor * c,
          int n_dims,
          int sections[GGML_MROPE_SECTIONS],
          int mode,
          int n_ctx_orig,
          float freq_base,
          float freq_scale,
          float ext_factor,
          float attn_factor,
          float beta_fast,
          float beta_slow);

  // deprecated: parameter list matches ggml_rope_ext minus the freq factors tensor c
  GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          int n_dims,
          int mode,
          int n_ctx_orig,
          float freq_base,
          float freq_scale,
          float ext_factor,
          float attn_factor,
          float beta_fast,
          float beta_slow),
      "use ggml_rope_ext instead");

  // deprecated: parameter list matches ggml_rope_ext_inplace minus the freq factors tensor c
  GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          int n_dims,
          int mode,
          int n_ctx_orig,
          float freq_base,
          float freq_scale,
          float ext_factor,
          float attn_factor,
          float beta_fast,
          float beta_slow),
      "use ggml_rope_ext_inplace instead");
  1484. // compute correction dims for YaRN RoPE scaling
  1485. GGML_API void ggml_rope_yarn_corr_dims(
  1486. int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
  1487. // rotary position embedding backward, i.e compute dx from dy
  1488. // a - dy
  1489. GGML_API struct ggml_tensor * ggml_rope_ext_back(
  1490. struct ggml_context * ctx,
  1491. struct ggml_tensor * a, // gradients of ggml_rope result
  1492. struct ggml_tensor * b, // positions
  1493. struct ggml_tensor * c, // freq factors
  1494. int n_dims,
  1495. int mode,
  1496. int n_ctx_orig,
  1497. float freq_base,
  1498. float freq_scale,
  1499. float ext_factor,
  1500. float attn_factor,
  1501. float beta_fast,
  1502. float beta_slow);
  1503. GGML_API struct ggml_tensor * ggml_rope_multi_back(
  1504. struct ggml_context * ctx,
  1505. struct ggml_tensor * a,
  1506. struct ggml_tensor * b,
  1507. struct ggml_tensor * c,
  1508. int n_dims,
  1509. int sections[4],
  1510. int mode,
  1511. int n_ctx_orig,
  1512. float freq_base,
  1513. float freq_scale,
  1514. float ext_factor,
  1515. float attn_factor,
  1516. float beta_fast,
  1517. float beta_slow);
  // clamp the elements of a to the range [min, max]
  // in-place, returns view(a)
  GGML_API struct ggml_tensor * ggml_clamp(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          float min,
          float max);
  // im2col
  // converts data into a format that effectively results in a convolution when combined with matrix multiplication
  // is_2D selects 2D vs 1D operation; dst_type is the element type of the result
  // NOTE(review): presumably the dimension-1 parameters (s1, p1, d1) are unused when is_2D is false — confirm
  GGML_API struct ggml_tensor * ggml_im2col(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel
          struct ggml_tensor * b, // data
          int s0, // stride dimension 0
          int s1, // stride dimension 1
          int p0, // padding dimension 0
          int p1, // padding dimension 1
          int d0, // dilation dimension 0
          int d1, // dilation dimension 1
          bool is_2D,
          enum ggml_type dst_type);

  // backward of ggml_im2col: maps the gradient of the im2col output (b) back to a tensor
  // with the shape of the im2col input (ne); the remaining parameters mirror ggml_im2col
  GGML_API struct ggml_tensor * ggml_im2col_back(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel
          struct ggml_tensor * b, // gradient of im2col output
          int64_t * ne, // shape of im2col input
          int s0, // stride dimension 0
          int s1, // stride dimension 1
          int p0, // padding dimension 0
          int p1, // padding dimension 1
          int d0, // dilation dimension 0
          int d1, // dilation dimension 1
          bool is_2D);
  // 1D convolution of data b with kernel a
  GGML_API struct ggml_tensor * ggml_conv_1d(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel
          struct ggml_tensor * b, // data
          int s0, // stride
          int p0, // padding
          int d0); // dilation

  // conv_1d with padding = half
  // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
  GGML_API struct ggml_tensor* ggml_conv_1d_ph(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel
          struct ggml_tensor * b, // data
          int s, // stride
          int d); // dilation

  // depthwise
  // TODO: this is very likely wrong for some cases! - needs more testing
  GGML_API struct ggml_tensor * ggml_conv_1d_dw(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel
          struct ggml_tensor * b, // data
          int s0, // stride
          int p0, // padding
          int d0); // dilation

  // depthwise conv_1d without an explicit padding argument
  // NOTE(review): by analogy with ggml_conv_1d_ph, presumably uses padding = half — confirm
  GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel
          struct ggml_tensor * b, // data
          int s0, // stride
          int d0); // dilation

  // transposed 1D convolution of data b with kernel a
  GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel
          struct ggml_tensor * b, // data
          int s0, // stride
          int p0, // padding
          int d0); // dilation
  // 2D convolution of data b with kernel a
  GGML_API struct ggml_tensor * ggml_conv_2d(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel
          struct ggml_tensor * b, // data
          int s0, // stride dimension 0
          int s1, // stride dimension 1
          int p0, // padding dimension 0
          int p1, // padding dimension 1
          int d0, // dilation dimension 0
          int d1); // dilation dimension 1

  // kernel size is a->ne[0] x a->ne[1]
  // stride is equal to kernel size
  // padding is zero
  // example:
  // a:     16   16    3  768
  // b:   1024 1024    3    1
  // res:   64   64  768    1
  // used in sam
  GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b);

  // kernel size is a->ne[0] x a->ne[1]
  // stride is 1
  // padding is half
  // example:
  // a:      3    3    256  256
  // b:     64   64    256    1
  // res:   64   64    256    1
  // used in sam
  GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b);

  // depthwise (via im2col and mul_mat)
  GGML_API struct ggml_tensor * ggml_conv_2d_dw(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel
          struct ggml_tensor * b, // data
          int s0, // stride dimension 0
          int s1, // stride dimension 1
          int p0, // padding dimension 0
          int p1, // padding dimension 1
          int d0, // dilation dimension 0
          int d1); // dilation dimension 1

  // Depthwise 2D convolution
  // may be faster than ggml_conv_2d_dw, but not available in all backends
  // a:   KW    KH    1    C    convolution kernel
  // b:   W     H     C    N    input data
  // res: W_out H_out C    N
  GGML_API struct ggml_tensor * ggml_conv_2d_dw_direct(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          int stride0,
          int stride1,
          int pad0,
          int pad1,
          int dilation0,
          int dilation1);

  // transposed 2D convolution with a single stride for both dimensions
  // NOTE(review): the _p0 suffix presumably means zero padding (cf. ggml_conv_2d_sk_p0) — confirm
  GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          int stride);

  // 2D convolution computed directly (name suggests no im2col intermediate — confirm per backend)
  GGML_API struct ggml_tensor * ggml_conv_2d_direct(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // convolution kernel [KW, KH, IC, OC]
          struct ggml_tensor * b, // input data [W, H, C, N]
          int s0, // stride dimension 0
          int s1, // stride dimension 1
          int p0, // padding dimension 0
          int p1, // padding dimension 1
          int d0, // dilation dimension 0
          int d1); // dilation dimension 1

  // 3D convolution
  // channel and batch counts are packed together in ne[3] of both a and b (IC * OC and C * N),
  // which is why n_channels, n_batch and n_channels_out are passed explicitly
  GGML_API struct ggml_tensor * ggml_conv_3d(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
          struct ggml_tensor * b, // input [W, H, D, C * N]
          int s0, // stride
          int s1,
          int s2,
          int p0, // padding
          int p1,
          int p2,
          int d0, // dilation
          int d1,
          int d2,
          int n_channels,
          int n_batch,
          int n_channels_out);
  // Pooling operator selector taken by ggml_pool_1d / ggml_pool_2d / ggml_pool_2d_back.
  // GGML_OP_POOL_COUNT is the number of operators, not an operator itself.
  enum ggml_op_pool {
      GGML_OP_POOL_MAX   = 0, // max pooling
      GGML_OP_POOL_AVG   = 1, // average pooling
      GGML_OP_POOL_COUNT = 2, // number of pooling operators
  };
  // 1D pooling of a with operator op (GGML_OP_POOL_MAX or GGML_OP_POOL_AVG)
  GGML_API struct ggml_tensor * ggml_pool_1d(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          enum ggml_op_pool op,
          int k0, // kernel size
          int s0, // stride
          int p0); // padding

  // 2D pooling: k0/k1 kernel size, s0/s1 stride, p0/p1 padding per dimension
  // the result will have 2*p0 padding for the first dimension
  // and 2*p1 padding for the second dimension
  GGML_API struct ggml_tensor * ggml_pool_2d(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          enum ggml_op_pool op,
          int k0,
          int k1,
          int s0,
          int s1,
          float p0,
          float p1);

  // backward pass of ggml_pool_2d; pooling parameters mirror the forward call
  // NOTE(review): a is presumably the gradient of the ggml_pool_2d result — confirm
  GGML_API struct ggml_tensor * ggml_pool_2d_back(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * af, // "a"/input used in forward pass
          enum ggml_op_pool op,
          int k0,
          int k1,
          int s0,
          int s1,
          float p0,
          float p1);
  // Interpolation modes. ggml_interpolate receives one of these in its mode argument,
  // optionally OR'ed with ggml_scale_flag bits.
  enum ggml_scale_mode {
      GGML_SCALE_MODE_NEAREST  = 0,
      GGML_SCALE_MODE_BILINEAR = 1,
      GGML_SCALE_MODE_COUNT    = 2  // number of modes, not a mode itself
  };

  // Option bits combined with a ggml_scale_mode value. They start at bit 8 (0x100),
  // so they never collide with the small mode values above.
  enum ggml_scale_flag {
      GGML_SCALE_FLAG_ALIGN_CORNERS = 0x100 // == 1 << 8
  };
  // interpolate
  // multiplies ne0 and ne1 by scale factor
  GGML_API struct ggml_tensor * ggml_upscale(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int scale_factor,
          enum ggml_scale_mode mode);

  // interpolate
  // scale (interpolate) a to the specified dimensions ne0..ne3
  GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_upscale_ext(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int ne0,
          int ne1,
          int ne2,
          int ne3,
          enum ggml_scale_mode mode),
      "use ggml_interpolate instead");

  // Up- or downsamples the input to the specified size.
  // 2D scale modes (eg. bilinear) are applied to the first two dimensions.
  GGML_API struct ggml_tensor * ggml_interpolate(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int64_t ne0,
          int64_t ne1,
          int64_t ne2,
          int64_t ne3,
          uint32_t mode); // ggml_scale_mode [ | ggml_scale_flag...]

  // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
  // p0..p3: number of zeros appended to dimensions 0..3
  GGML_API struct ggml_tensor * ggml_pad(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int p0,
          int p1,
          int p2,
          int p3);

  // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
  // NOTE(review): p0 / p1 are presumably the element counts added before / after (the
  // example above corresponds to p0 = p1 = 1) — confirm
  GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int p0,
          int p1);

  // Move tensor elements by an offset given for each dimension. Elements that
  // are shifted beyond the last position are wrapped around to the beginning.
  GGML_API struct ggml_tensor * ggml_roll(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int shift0,
          int shift1,
          int shift2,
          int shift3);

  // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
  // timesteps: [N,]
  // return: [N, dim]
  GGML_API struct ggml_tensor * ggml_timestep_embedding(
          struct ggml_context * ctx,
          struct ggml_tensor * timesteps,
          int dim,
          int max_period);
  // sort rows: direction used by ggml_argsort
  enum ggml_sort_order {
      GGML_SORT_ORDER_ASC  = 0, // ascending
      GGML_SORT_ORDER_DESC = 1, // descending
  };
  // per-row argsort of a in the requested order
  // NOTE(review): presumably returns the indices that would sort each row — confirm result type
  GGML_API struct ggml_tensor * ggml_argsort(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          enum ggml_sort_order order);

  // 1D tensor of values from start towards stop with increment step
  // NOTE(review): whether stop is inclusive is not stated here — confirm (likely exclusive)
  GGML_API struct ggml_tensor * ggml_arange(
          struct ggml_context * ctx,
          float start,
          float stop,
          float step);

  // top k elements per row
  GGML_API struct ggml_tensor * ggml_top_k(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int k);
  // the batch dimension of the flash-attention mask must be padded to a multiple of this
  // (see n_batch_pad below)
  #define GGML_KQ_MASK_PAD 64

  // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
  // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
  // v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
  // mask: [n_kv,     n_batch_pad, ne32,      ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
  // res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
  //
  // broadcast:
  //   n_head % n_head_kv == 0
  //   n_head % ne32      == 0
  //   ne3    % ne33      == 0
  //
  GGML_API struct ggml_tensor * ggml_flash_attn_ext(
          struct ggml_context * ctx,
          struct ggml_tensor * q,
          struct ggml_tensor * k,
          struct ggml_tensor * v,
          struct ggml_tensor * mask,
          float scale,
          float max_bias,
          float logit_softcap);

  // set / get the computation precision of a flash_attn_ext node a
  GGML_API void ggml_flash_attn_ext_set_prec(
          struct ggml_tensor * a,
          enum ggml_prec prec);

  GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
          const struct ggml_tensor * a);

  // attach a sinks tensor to the flash_attn_ext node a (returns void; modifies the node)
  // NOTE(review): expected shape/type of sinks is not documented here — confirm
  GGML_API void ggml_flash_attn_ext_add_sinks(
          struct ggml_tensor * a,
          struct ggml_tensor * sinks);

  // TODO: needs to be adapted to ggml_flash_attn_ext
  GGML_API struct ggml_tensor * ggml_flash_attn_back(
          struct ggml_context * ctx,
          struct ggml_tensor * q,
          struct ggml_tensor * k,
          struct ggml_tensor * v,
          struct ggml_tensor * d,
          bool masked);
  // state-space-model (SSM) convolution
  // NOTE(review): shapes of sx and c are not documented in this header — confirm against the implementation
  GGML_API struct ggml_tensor * ggml_ssm_conv(
          struct ggml_context * ctx,
          struct ggml_tensor * sx,
          struct ggml_tensor * c);

  // state-space-model (SSM) selective scan over state s and input x
  // NOTE(review): shapes of dt/A/B/C and the role of ids are not documented here — confirm
  GGML_API struct ggml_tensor * ggml_ssm_scan(
          struct ggml_context * ctx,
          struct ggml_tensor * s,
          struct ggml_tensor * x,
          struct ggml_tensor * dt,
          struct ggml_tensor * A,
          struct ggml_tensor * B,
          struct ggml_tensor * C,
          struct ggml_tensor * ids);
  // partition into non-overlapping windows with padding if needed
  // example:
  // a:   768   64   64    1
  // w:    14
  // res: 768   14   14    25
  // used in sam
  GGML_API struct ggml_tensor * ggml_win_part(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int w);

  // reverse of ggml_win_part
  // w is the window size; w0 / h0 presumably give the original (unpadded) width / height — confirm
  // used in sam
  GGML_API struct ggml_tensor * ggml_win_unpart(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int w0,
          int h0,
          int w);
  // apply the unary operator op to a
  GGML_API struct ggml_tensor * ggml_unary(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          enum ggml_unary_op op);

  // in-place variant of ggml_unary
  GGML_API struct ggml_tensor * ggml_unary_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          enum ggml_unary_op op);
  // relative position lookup for query size qh and key size kh
  // used in sam
  GGML_API struct ggml_tensor * ggml_get_rel_pos(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          int qh,
          int kh);

  // add relative position terms pw (width) and ph (height) to a
  // NOTE(review): the width/height reading of pw/ph is inferred from the names — confirm
  // used in sam
  GGML_API struct ggml_tensor * ggml_add_rel_pos(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * pw,
          struct ggml_tensor * ph);

  // in-place variant of ggml_add_rel_pos
  GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * pw,
          struct ggml_tensor * ph);
  // RWKV v6 WKV operator; state carries the recurrent state
  // NOTE(review): tensor shapes are not documented in this header — confirm against the implementation
  GGML_API struct ggml_tensor * ggml_rwkv_wkv6(
          struct ggml_context * ctx,
          struct ggml_tensor * k,
          struct ggml_tensor * v,
          struct ggml_tensor * r,
          struct ggml_tensor * tf,
          struct ggml_tensor * td,
          struct ggml_tensor * state);

  // gated linear attention over k, v, q with gate g, recurrent state and output scale
  // NOTE(review): tensor shapes are not documented in this header — confirm
  GGML_API struct ggml_tensor * ggml_gated_linear_attn(
          struct ggml_context * ctx,
          struct ggml_tensor * k,
          struct ggml_tensor * v,
          struct ggml_tensor * q,
          struct ggml_tensor * g,
          struct ggml_tensor * state,
          float scale);

  // RWKV v7 WKV operator; state carries the recurrent state
  // NOTE(review): tensor shapes are not documented in this header — confirm
  GGML_API struct ggml_tensor * ggml_rwkv_wkv7(
          struct ggml_context * ctx,
          struct ggml_tensor * r,
          struct ggml_tensor * w,
          struct ggml_tensor * k,
          struct ggml_tensor * v,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          struct ggml_tensor * state);
  // custom operators
  // callbacks receive the output tensor dst, the input tensor(s), and an opaque userdata pointer;
  // NOTE(review): ith / nth are presumably this task's index and the total task count — confirm
  typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
  typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata);
  typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata);

  #define GGML_N_TASKS_MAX (-1)
  // n_tasks == GGML_N_TASKS_MAX means to use max number of tasks

  // apply fun to one input a
  GGML_API struct ggml_tensor * ggml_map_custom1(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          ggml_custom1_op_t fun,
          int n_tasks,
          void * userdata);

  // in-place variant of ggml_map_custom1
  GGML_API struct ggml_tensor * ggml_map_custom1_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          ggml_custom1_op_t fun,
          int n_tasks,
          void * userdata);

  // apply fun to two inputs a, b
  GGML_API struct ggml_tensor * ggml_map_custom2(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          ggml_custom2_op_t fun,
          int n_tasks,
          void * userdata);

  // in-place variant of ggml_map_custom2
  GGML_API struct ggml_tensor * ggml_map_custom2_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          ggml_custom2_op_t fun,
          int n_tasks,
          void * userdata);

  // apply fun to three inputs a, b, c
  GGML_API struct ggml_tensor * ggml_map_custom3(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          struct ggml_tensor * c,
          ggml_custom3_op_t fun,
          int n_tasks,
          void * userdata);

  // in-place variant of ggml_map_custom3
  GGML_API struct ggml_tensor * ggml_map_custom3_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * b,
          struct ggml_tensor * c,
          ggml_custom3_op_t fun,
          int n_tasks,
          void * userdata);
  // generic custom-op callback: only dst is passed; inputs are presumably reached through
  // dst's sources (the args given at construction) — confirm
  typedef void (*ggml_custom_op_t)(struct ggml_tensor * dst , int ith, int nth, void * userdata);

  // new tensor of the given type and shape [ne0, ne1, ne2, ne3], computed by fun from
  // args[0 .. n_args); n_tasks follows the GGML_N_TASKS_MAX convention of ggml_map_custom*
  GGML_API struct ggml_tensor * ggml_custom_4d(
          struct ggml_context * ctx,
          enum ggml_type type,
          int64_t ne0,
          int64_t ne1,
          int64_t ne2,
          int64_t ne3,
          struct ggml_tensor ** args,
          int n_args,
          ggml_custom_op_t fun,
          int n_tasks,
          void * userdata);

  // like ggml_custom_4d, but the result is written into a (in-place)
  GGML_API struct ggml_tensor * ggml_custom_inplace(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor ** args,
          int n_args,
          ggml_custom_op_t fun,
          int n_tasks,
          void * userdata);
  // loss function
  GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // logits
          struct ggml_tensor * b); // labels

  // backward pass of ggml_cross_entropy_loss
  GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
          struct ggml_context * ctx,
          struct ggml_tensor * a, // logits
          struct ggml_tensor * b, // labels
          struct ggml_tensor * c); // gradients of cross_entropy_loss result

  // AdamW optimizer step
  // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
  // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
  // a: parameters, grad: their gradient, m / v: presumably the first / second moment
  // estimates (per the AdamW algorithm) — confirm
  GGML_API struct ggml_tensor * ggml_opt_step_adamw(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * grad,
          struct ggml_tensor * m,
          struct ggml_tensor * v,
          struct ggml_tensor * adamw_params); // parameters such as the learning rate

  // stochastic gradient descent step (with weight decay)
  GGML_API struct ggml_tensor * ggml_opt_step_sgd(
          struct ggml_context * ctx,
          struct ggml_tensor * a,
          struct ggml_tensor * grad,
          struct ggml_tensor * sgd_params); // alpha, weight decay
  //
  // automatic differentiation
  //

  // expand cgraph with the computation that produces tensor
  GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
  // NOTE(review): grad_accs presumably supplies per-node gradient accumulators (may be NULL?) — confirm
  GGML_API void ggml_build_backward_expand(
          struct ggml_context * ctx, // context for gradient computation
          struct ggml_cgraph * cgraph,
          struct ggml_tensor ** grad_accs);

  // graph allocation in a context
  GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false
  GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads);
  GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads);
  GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst);
  GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1
  GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph);

  GGML_API int ggml_graph_size (struct ggml_cgraph * cgraph);
  GGML_API struct ggml_tensor * ggml_graph_node (struct ggml_cgraph * cgraph, int i); // if i < 0, returns nodes[n_nodes + i]
  GGML_API struct ggml_tensor ** ggml_graph_nodes (struct ggml_cgraph * cgraph);
  GGML_API int ggml_graph_n_nodes(struct ggml_cgraph * cgraph);
  GGML_API void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

  // memory overhead of a default-sized graph / of a graph with the given size and grads setting
  // NOTE(review): presumably in bytes — confirm
  GGML_API size_t ggml_graph_overhead(void);
  GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads);

  // look up a tensor by name, or the gradient / gradient accumulator of a node
  GGML_API struct ggml_tensor * ggml_graph_get_tensor (const struct ggml_cgraph * cgraph, const char * name);
  GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);
  GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node);

  // print info and performance information for the graph
  GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);

  // dump the graph into a file using the dot format
  // NOTE(review): gb / gf presumably are the backward / forward graphs — confirm
  GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);

  // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
  typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);

  // Set callback for all future logging events.
  // If this is not called, or NULL is supplied, everything is output on stderr.
  GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data);

  // NOTE(review): presumably fills tensor with zeros and returns it — confirm
  GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
  //
  // quantization
  //

  // - ggml_quantize_init can be called multiple times with the same type
  //   it will only initialize the quantization tables for the first call or after ggml_quantize_free
  //   automatically called by ggml_quantize_chunk for convenience
  //
  // - ggml_quantize_free will free any memory allocated by ggml_quantize_init
  //   call this at the end of the program to avoid memory leaks
  //
  // note: these are thread-safe
  //
  GGML_API void ggml_quantize_init(enum ggml_type type);
  GGML_API void ggml_quantize_free(void);

  // some quantization types cannot be used without an importance matrix
  GGML_API bool ggml_quantize_requires_imatrix(enum ggml_type type);

  // calls ggml_quantize_init internally (i.e. can allocate memory)
  // quantizes nrows rows of n_per_row floats from src into dst as the given type;
  // imatrix is the optional importance matrix (see ggml_quantize_requires_imatrix)
  // NOTE(review): start is presumably an element offset into src/dst, and the return value
  // presumably the number of bytes written — confirm
  GGML_API size_t ggml_quantize_chunk(
          enum ggml_type type,
          const float * src,
          void * dst,
          int64_t start,
          int64_t nrows,
          int64_t n_per_row,
          const float * imatrix);
  // GGML_RESTRICT: portable spelling of the C99 `restrict` qualifier.
  //  - C:   `restrict`, except MSVC in a pre-C11 C mode, which only accepts `__restrict`
  //  - C++: restrict is not standard, so use the compiler extension if there is one, else nothing
  #if !defined(__cplusplus)
  #    if defined (_MSC_VER) && (__STDC_VERSION__ < 201112L)
  #        define GGML_RESTRICT __restrict
  #    else
  #        define GGML_RESTRICT restrict
  #    endif
  #elif defined(__GNUC__)
  #    define GGML_RESTRICT __restrict__
  #elif defined(__clang__) || defined(_MSC_VER)
  #    define GGML_RESTRICT __restrict
  #else
  #    define GGML_RESTRICT
  #endif
  // conversion callbacks between a ggml type and float
  // NOTE(review): k is presumably the number of elements converted from x into y — confirm
  typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
  typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

  // per-type metadata, obtained via ggml_get_type_traits
  struct ggml_type_traits {
      const char * type_name;
      int64_t blck_size; // NOTE(review): presumably elements per block — confirm
      int64_t blck_size_interleave; // interleave elements in blocks
      size_t type_size; // NOTE(review): presumably bytes per block — confirm
      bool is_quantized;
      ggml_to_float_t to_float;
      ggml_from_float_t from_float_ref;
  };

  GGML_API const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type);
  // ggml threadpool
  // TODO: currently only a few threadpool functions are part of the base ggml API, while the
  // rest live in the CPU backend; the goal is an API that other backends can use as well,
  // with everything moved into the ggml base.

  // Thread scheduling priorities; LOW is the only value below NORMAL (which is 0).
  enum ggml_sched_priority {
      GGML_SCHED_PRIO_LOW      = -1,
      GGML_SCHED_PRIO_NORMAL   =  0,
      GGML_SCHED_PRIO_MEDIUM   =  1,
      GGML_SCHED_PRIO_HIGH     =  2,
      GGML_SCHED_PRIO_REALTIME =  3
  };
  // threadpool params
  // Use ggml_threadpool_params_default() or ggml_threadpool_params_init() to populate the defaults
  struct ggml_threadpool_params {
      bool cpumask[GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings)
      int n_threads; // number of threads
      enum ggml_sched_priority prio; // thread priority
      uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling)
      bool strict_cpu; // strict cpu placement
      bool paused; // start in paused state
  };

  struct ggml_threadpool; // forward declaration, see ggml.c

  // opaque handle to a threadpool
  typedef struct ggml_threadpool * ggml_threadpool_t;

  // return / fill in the default parameters for n_threads threads
  GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads);
  GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
  // compare two parameter sets (NOTE(review): presumably true when they are equivalent — confirm)
  GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
  2129. #ifdef __cplusplus
  2130. }
  2131. #endif