ggml.h 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319
  1. #pragma once
  2. //
  3. // GGML Tensor Library
  4. //
// This documentation is still a work in progress.
// If you would like specific topics to be covered, feel free to leave a comment:
  7. //
  8. // https://github.com/ggerganov/whisper.cpp/issues/40
  9. //
  10. // ## Overview
  11. //
  12. // This library implements:
  13. //
  14. // - a set of tensor operations
  15. // - automatic differentiation
  16. // - basic optimization algorithms
  17. //
  18. // The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
  19. // but is not limited to, the following:
  20. //
  21. // - linear regression
  22. // - support vector machines
  23. // - neural networks
  24. //
  25. // The library allows the user to define a certain function using the available tensor operations. This function
  26. // definition is represented internally via a computation graph. Each tensor operation in the function definition
  27. // corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
  28. // function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
  29. // using one of the available optimization algorithms.
  30. //
  31. // For example, here we define the function: f(x) = a*x^2 + b
  32. //
  33. // {
  34. // struct ggml_init_params params = {
  35. // .mem_size = 16*1024*1024,
  36. // .mem_buffer = NULL,
  37. // };
  38. //
  39. // // memory allocation happens here
  40. // struct ggml_context * ctx = ggml_init(params);
  41. //
  42. // struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  43. //
  44. // ggml_set_param(ctx, x); // x is an input variable
  45. //
  46. // struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  47. // struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
  48. // struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
  49. // struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
  50. //
  51. // ...
  52. // }
  53. //
  54. // Notice that the function definition above does not involve any actual computation. The computation is performed only
  55. // when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
  56. //
//   {
//       ...
//
//       struct ggml_cgraph gf = ggml_build_forward(f);
//
//       // set the input variable and parameter values
//       ggml_set_f32(x, 2.0f);
//       ggml_set_f32(a, 3.0f);
//       ggml_set_f32(b, 4.0f);
//
//       ggml_graph_compute(ctx, &gf);
//
//       printf("f = %f\n", ggml_get_f32_1d(f, 0));
//
//       ...
//   }
  73. //
  74. // The actual computation is performed in the ggml_graph_compute() function.
  75. //
// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
// buffer and, after defining the computation graph, call the ggml_used_mem() function to find out how much memory
// was actually needed.
  81. //
  82. // The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
  83. // differentiation and optimization algorithms.
  84. //
// The described approach allows the user to define the function graph once and then compute its forward or backward
// graphs multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This
// way the user can avoid the memory allocation overhead at runtime.
  88. //
// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
// citizens. Quantized (GGML_TYPE_Q*) and integer (GGML_TYPE_I8/I16/I32) types are also declared in enum ggml_type,
// and in theory the library can be extended to support other formats such as FP8.
  91. //
  92. // Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
  93. // and binary operations. Most of the available operations fall into one of these two categories. With time, it became
  94. // clear that the library needs to support more complex operations. The way to support these operations is not clear
  95. // yet, but a few examples are demonstrated in the following operations:
  96. //
  97. // - ggml_permute()
  98. // - ggml_conv_1d_1s()
  99. // - ggml_conv_1d_2s()
  100. //
  101. // For each tensor operator, the library implements a forward and backward computation function. The forward function
  102. // computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
  103. // input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
  104. // calculus class, or watch the following video:
  105. //
  106. // What is Automatic Differentiation?
  107. // https://www.youtube.com/watch?v=wG_nF1awSSY
  108. //
  109. //
  110. // ## Tensor data (struct ggml_tensor)
  111. //
  112. // The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
  113. // the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
  114. // pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
  115. //
//   {
//       struct ggml_tensor * c = ggml_add(ctx, a, b);
//
//       assert(c->src0 == a);
//       assert(c->src1 == b);
//   }
  122. //
  123. // The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
  124. // number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows
  125. // to store tensors that are not contiguous in memory, which is useful for operations such as transposition and
  126. // permutation. All tensor operations have to take the stride into account and not assume that the tensor is
  127. // contiguous in memory.
  128. //
// The data of the tensor is accessed via the "data" pointer. In the example below, a[i0, i1] denotes the element at
// index i0 along dimension 0 (ne[0] = 2) and index i1 along dimension 1 (ne[1] = 3):
//
//   {
//       struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
//
//       // a[1, 2] = 1.0f;
//       *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
//
//       // a[1, 0] = 2.0f;
//       *(float *) ((char *) a->data + 0*a->nb[1] + 1*a->nb[0]) = 2.0f;
//
//       ...
//   }
  142. //
  143. // Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used.
  144. //
  145. // ## The matrix multiplication operator (ggml_mul_mat)
  146. //
  147. // TODO
  148. //
  149. //
  150. // ## Multi-threading
  151. //
  152. // TODO
  153. //
  154. //
  155. // ## Overview of ggml.c
  156. //
  157. // TODO
  158. //
  159. //
  160. // ## SIMD optimizations
  161. //
  162. // TODO
  163. //
  164. //
  165. // ## Debugging ggml
  166. //
  167. // TODO
  168. //
  169. //
  170. #ifdef GGML_SHARED
  171. # if defined(_WIN32) && !defined(__MINGW32__)
  172. # ifdef GGML_BUILD
  173. # define GGML_API __declspec(dllexport)
  174. # else
  175. # define GGML_API __declspec(dllimport)
  176. # endif
  177. # else
  178. # define GGML_API __attribute__ ((visibility ("default")))
  179. # endif
  180. #else
  181. # define GGML_API
  182. #endif
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
  186. #define GGML_FILE_MAGIC 0x67676d6c // "ggml"
  187. #define GGML_FILE_VERSION 1
  188. #define GGML_QNT_VERSION 2 // bump this on quantization format changes
  189. #define GGML_QNT_VERSION_FACTOR 1000 // do not change this
  190. #define GGML_MAX_DIMS 4
  191. #define GGML_MAX_NODES 4096
  192. #define GGML_MAX_PARAMS 256
  193. #define GGML_MAX_CONTEXTS 64
  194. #define GGML_MAX_OPT 4
  195. #define GGML_MAX_NAME 32
  196. #define GGML_DEFAULT_N_THREADS 4
  197. #define GGML_ASSERT(x) \
  198. do { \
  199. if (!(x)) { \
  200. fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \
  201. abort(); \
  202. } \
  203. } while (0)
  204. #ifdef __cplusplus
  205. extern "C" {
  206. #endif
  207. #ifdef __ARM_NEON
  208. // we use the built-in 16-bit float type
  209. typedef __fp16 ggml_fp16_t;
  210. #else
  211. typedef uint16_t ggml_fp16_t;
  212. #endif
  213. // convert FP16 <-> FP32
  214. GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x);
  215. GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x);
  216. GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, size_t n);
  217. GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, size_t n);
  218. struct ggml_object;
  219. struct ggml_context;
  220. enum ggml_type {
  221. GGML_TYPE_F32 = 0,
  222. GGML_TYPE_F16 = 1,
  223. GGML_TYPE_Q4_0 = 2,
  224. GGML_TYPE_Q4_1 = 3,
  225. // GGML_TYPE_Q4_2 = 4, support has been removed
  226. // GGML_TYPE_Q4_3 (5) support has been removed
  227. GGML_TYPE_Q5_0 = 6,
  228. GGML_TYPE_Q5_1 = 7,
  229. GGML_TYPE_Q8_0 = 8,
  230. GGML_TYPE_Q8_1 = 9,
  231. // k-quantizations
  232. GGML_TYPE_Q2_K = 10,
  233. GGML_TYPE_Q3_K = 11,
  234. GGML_TYPE_Q4_K = 12,
  235. GGML_TYPE_Q5_K = 13,
  236. GGML_TYPE_Q6_K = 14,
  237. GGML_TYPE_Q8_K = 15,
  238. GGML_TYPE_I8,
  239. GGML_TYPE_I16,
  240. GGML_TYPE_I32,
  241. GGML_TYPE_COUNT,
  242. };
  243. enum ggml_backend {
  244. GGML_BACKEND_CPU = 0,
  245. GGML_BACKEND_GPU = 10,
  246. GGML_BACKEND_GPU_SPLIT = 20,
  247. };
  248. // model file types
  249. enum ggml_ftype {
  250. GGML_FTYPE_UNKNOWN = -1,
  251. GGML_FTYPE_ALL_F32 = 0,
  252. GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
  253. GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
  254. GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
  255. GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
  256. GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
  257. GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
  258. GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
  259. GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors
  260. GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors
  261. GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
  262. GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
  263. GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
  264. };
  265. // available tensor operations:
  266. enum ggml_op {
  267. GGML_OP_NONE = 0,
  268. GGML_OP_DUP,
  269. GGML_OP_ADD,
  270. GGML_OP_ADD1,
  271. GGML_OP_ACC,
  272. GGML_OP_SUB,
  273. GGML_OP_MUL,
  274. GGML_OP_DIV,
  275. GGML_OP_SQR,
  276. GGML_OP_SQRT,
  277. GGML_OP_LOG,
  278. GGML_OP_SUM,
  279. GGML_OP_SUM_ROWS,
  280. GGML_OP_MEAN,
  281. GGML_OP_REPEAT,
  282. GGML_OP_REPEAT_BACK,
  283. GGML_OP_ABS,
  284. GGML_OP_SGN,
  285. GGML_OP_NEG,
  286. GGML_OP_STEP,
  287. GGML_OP_RELU,
  288. GGML_OP_GELU,
  289. GGML_OP_SILU,
  290. GGML_OP_SILU_BACK,
  291. GGML_OP_NORM, // normalize
  292. GGML_OP_RMS_NORM,
  293. GGML_OP_RMS_NORM_BACK,
  294. GGML_OP_MUL_MAT,
  295. GGML_OP_OUT_PROD,
  296. GGML_OP_SCALE,
  297. GGML_OP_SET,
  298. GGML_OP_CPY,
  299. GGML_OP_CONT,
  300. GGML_OP_RESHAPE,
  301. GGML_OP_VIEW,
  302. GGML_OP_PERMUTE,
  303. GGML_OP_TRANSPOSE,
  304. GGML_OP_GET_ROWS,
  305. GGML_OP_GET_ROWS_BACK,
  306. GGML_OP_DIAG,
  307. GGML_OP_DIAG_MASK_INF,
  308. GGML_OP_DIAG_MASK_ZERO,
  309. GGML_OP_SOFT_MAX,
  310. GGML_OP_SOFT_MAX_BACK,
  311. GGML_OP_ROPE,
  312. GGML_OP_ROPE_BACK,
  313. GGML_OP_ALIBI,
  314. GGML_OP_CLAMP,
  315. GGML_OP_CONV_1D_1S,
  316. GGML_OP_CONV_1D_2S,
  317. GGML_OP_FLASH_ATTN,
  318. GGML_OP_FLASH_FF,
  319. GGML_OP_FLASH_ATTN_BACK,
  320. GGML_OP_MAP_UNARY,
  321. GGML_OP_MAP_BINARY,
  322. GGML_OP_CROSS_ENTROPY_LOSS,
  323. GGML_OP_CROSS_ENTROPY_LOSS_BACK,
  324. GGML_OP_COUNT,
  325. };
  326. // ggml object
  327. struct ggml_object {
  328. size_t offs;
  329. size_t size;
  330. struct ggml_object * next;
  331. char padding[8];
  332. };
  333. static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);
  334. // n-dimensional tensor
  335. struct ggml_tensor {
  336. enum ggml_type type;
  337. enum ggml_backend backend;
  338. int n_dims;
  339. int64_t ne[GGML_MAX_DIMS]; // number of elements
  340. size_t nb[GGML_MAX_DIMS]; // stride in bytes:
  341. // nb[0] = sizeof(type)
  342. // nb[1] = nb[0] * ne[0] + padding
  343. // nb[i] = nb[i-1] * ne[i-1]
  344. // compute data
  345. enum ggml_op op;
  346. bool is_param;
  347. struct ggml_tensor * grad;
  348. struct ggml_tensor * src0;
  349. struct ggml_tensor * src1;
  350. struct ggml_tensor * opt[GGML_MAX_OPT];
  351. // thread scheduling
  352. int n_tasks;
  353. // performance
  354. int perf_runs;
  355. int64_t perf_cycles;
  356. int64_t perf_time_us;
  357. void * data;
  358. char name[GGML_MAX_NAME];
  359. void * extra; // extra things e.g. for ggml-cuda.cu
  360. char padding[4];
  361. };
  362. static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
  363. // computation graph
  364. struct ggml_cgraph {
  365. int n_nodes;
  366. int n_leafs;
  367. int n_threads;
  368. size_t work_size;
  369. struct ggml_tensor * work;
  370. struct ggml_tensor * nodes[GGML_MAX_NODES];
  371. struct ggml_tensor * grads[GGML_MAX_NODES];
  372. struct ggml_tensor * leafs[GGML_MAX_NODES];
  373. // performance
  374. int perf_runs;
  375. int64_t perf_cycles;
  376. int64_t perf_time_us;
  377. };
  378. // scratch buffer
  379. struct ggml_scratch {
  380. size_t offs;
  381. size_t size;
  382. void * data;
  383. };
  384. struct ggml_init_params {
  385. // memory pool
  386. size_t mem_size; // bytes
  387. void * mem_buffer; // if NULL, memory will be allocated internally
  388. bool no_alloc; // don't allocate memory for the tensor data
  389. };
  390. // compute types
  391. enum ggml_task_type {
  392. GGML_TASK_INIT = 0,
  393. GGML_TASK_COMPUTE,
  394. GGML_TASK_FINALIZE,
  395. };
  396. struct ggml_compute_params {
  397. enum ggml_task_type type;
  398. // ith = thread index, nth = number of threads
  399. int ith, nth;
  400. // work buffer for all threads
  401. size_t wsize;
  402. void * wdata;
  403. };
  404. // misc
  405. GGML_API void ggml_time_init(void); // call this once at the beginning of the program
  406. GGML_API int64_t ggml_time_ms(void);
  407. GGML_API int64_t ggml_time_us(void);
  408. GGML_API int64_t ggml_cycles(void);
  409. GGML_API int64_t ggml_cycles_per_ms(void);
  410. GGML_API void ggml_print_object (const struct ggml_object * obj);
  411. GGML_API void ggml_print_objects(const struct ggml_context * ctx);
  412. GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);
  413. GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
  414. GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
  415. GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
  416. GGML_API int ggml_blck_size (enum ggml_type type);
  417. GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
  418. GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float
  419. GGML_API const char * ggml_type_name(enum ggml_type type);
  420. GGML_API const char * ggml_op_name (enum ggml_op op);
  421. GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
  422. GGML_API bool ggml_is_quantized(enum ggml_type type);
  423. // TODO: temporary until model loading of ggml examples is refactored
  424. GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype);
  425. GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor);
  426. GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor);
  427. GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor);
  428. // use this to compute the memory overhead of a tensor
  429. GGML_API size_t ggml_tensor_overhead(void);
  430. // main
  431. GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
  432. GGML_API void ggml_free(struct ggml_context * ctx);
  433. GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
  434. GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
  435. GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
  436. GGML_API void * ggml_get_mem_buffer(struct ggml_context * ctx);
  437. GGML_API size_t ggml_get_mem_size (struct ggml_context * ctx);
  438. GGML_API struct ggml_tensor * ggml_new_tensor(
  439. struct ggml_context * ctx,
  440. enum ggml_type type,
  441. int n_dims,
  442. const int64_t *ne);
  443. GGML_API struct ggml_tensor * ggml_new_tensor_1d(
  444. struct ggml_context * ctx,
  445. enum ggml_type type,
  446. int64_t ne0);
  447. GGML_API struct ggml_tensor * ggml_new_tensor_2d(
  448. struct ggml_context * ctx,
  449. enum ggml_type type,
  450. int64_t ne0,
  451. int64_t ne1);
  452. GGML_API struct ggml_tensor * ggml_new_tensor_3d(
  453. struct ggml_context * ctx,
  454. enum ggml_type type,
  455. int64_t ne0,
  456. int64_t ne1,
  457. int64_t ne2);
  458. GGML_API struct ggml_tensor * ggml_new_tensor_4d(
  459. struct ggml_context * ctx,
  460. enum ggml_type type,
  461. int64_t ne0,
  462. int64_t ne1,
  463. int64_t ne2,
  464. int64_t ne3);
  465. GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
  466. GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
  467. GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
  468. GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);
  469. GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
  470. GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
  471. GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
  472. GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
  473. GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
  474. GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
  475. GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
  476. GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
  477. GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
  478. GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
  479. GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor);
  480. GGML_API void ggml_set_name(struct ggml_tensor * tensor, const char * name);
  481. //
  482. // operations on tensors with backpropagation
  483. //
  484. GGML_API struct ggml_tensor * ggml_dup(
  485. struct ggml_context * ctx,
  486. struct ggml_tensor * a);
  487. GGML_API struct ggml_tensor * ggml_add(
  488. struct ggml_context * ctx,
  489. struct ggml_tensor * a,
  490. struct ggml_tensor * b);
  491. GGML_API struct ggml_tensor * ggml_add_inplace(
  492. struct ggml_context * ctx,
  493. struct ggml_tensor * a,
  494. struct ggml_tensor * b);
  495. GGML_API struct ggml_tensor * ggml_add1(
  496. struct ggml_context * ctx,
  497. struct ggml_tensor * a,
  498. struct ggml_tensor * b);
  499. GGML_API struct ggml_tensor * ggml_add1_inplace(
  500. struct ggml_context * ctx,
  501. struct ggml_tensor * a,
  502. struct ggml_tensor * b);
  503. GGML_API struct ggml_tensor * ggml_acc(
  504. struct ggml_context * ctx,
  505. struct ggml_tensor * a,
  506. struct ggml_tensor * b,
  507. size_t nb1,
  508. size_t nb2,
  509. size_t nb3,
  510. size_t offset);
  511. GGML_API struct ggml_tensor * ggml_acc_inplace(
  512. struct ggml_context * ctx,
  513. struct ggml_tensor * a,
  514. struct ggml_tensor * b,
  515. size_t nb1,
  516. size_t nb2,
  517. size_t nb3,
  518. size_t offset);
  519. GGML_API struct ggml_tensor * ggml_sub(
  520. struct ggml_context * ctx,
  521. struct ggml_tensor * a,
  522. struct ggml_tensor * b);
  523. GGML_API struct ggml_tensor * ggml_mul(
  524. struct ggml_context * ctx,
  525. struct ggml_tensor * a,
  526. struct ggml_tensor * b);
  527. GGML_API struct ggml_tensor * ggml_div(
  528. struct ggml_context * ctx,
  529. struct ggml_tensor * a,
  530. struct ggml_tensor * b);
  531. GGML_API struct ggml_tensor * ggml_sqr(
  532. struct ggml_context * ctx,
  533. struct ggml_tensor * a);
  534. GGML_API struct ggml_tensor * ggml_sqrt(
  535. struct ggml_context * ctx,
  536. struct ggml_tensor * a);
  537. GGML_API struct ggml_tensor * ggml_log(
  538. struct ggml_context * ctx,
  539. struct ggml_tensor * a);
  540. GGML_API struct ggml_tensor * ggml_log_inplace(
  541. struct ggml_context * ctx,
  542. struct ggml_tensor * a);
  543. // return scalar
  544. GGML_API struct ggml_tensor * ggml_sum(
  545. struct ggml_context * ctx,
  546. struct ggml_tensor * a);
  547. // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d]
  548. GGML_API struct ggml_tensor * ggml_sum_rows(
  549. struct ggml_context * ctx,
  550. struct ggml_tensor * a);
  551. // mean along rows
  552. GGML_API struct ggml_tensor * ggml_mean(
  553. struct ggml_context * ctx,
  554. struct ggml_tensor * a);
  555. // if a is the same shape as b, and a is not parameter, return a
  556. // otherwise, return a new tensor: repeat(a) to fit in b
  557. GGML_API struct ggml_tensor * ggml_repeat(
  558. struct ggml_context * ctx,
  559. struct ggml_tensor * a,
  560. struct ggml_tensor * b);
  561. GGML_API struct ggml_tensor * ggml_repeat_back(
  562. struct ggml_context * ctx,
  563. struct ggml_tensor * a,
  564. struct ggml_tensor * b);
  565. GGML_API struct ggml_tensor * ggml_abs(
  566. struct ggml_context * ctx,
  567. struct ggml_tensor * a);
  568. GGML_API struct ggml_tensor * ggml_sgn(
  569. struct ggml_context * ctx,
  570. struct ggml_tensor * a);
  571. GGML_API struct ggml_tensor * ggml_neg(
  572. struct ggml_context * ctx,
  573. struct ggml_tensor * a);
  574. GGML_API struct ggml_tensor * ggml_step(
  575. struct ggml_context * ctx,
  576. struct ggml_tensor * a);
  577. GGML_API struct ggml_tensor * ggml_relu(
  578. struct ggml_context * ctx,
  579. struct ggml_tensor * a);
  580. // TODO: double-check this computation is correct
  581. GGML_API struct ggml_tensor * ggml_gelu(
  582. struct ggml_context * ctx,
  583. struct ggml_tensor * a);
  584. GGML_API struct ggml_tensor * ggml_silu(
  585. struct ggml_context * ctx,
  586. struct ggml_tensor * a);
  587. // a - x
  588. // b - dy
  589. GGML_API struct ggml_tensor * ggml_silu_back(
  590. struct ggml_context * ctx,
  591. struct ggml_tensor * a,
  592. struct ggml_tensor * b);
  593. // normalize along rows
  594. // TODO: eps is hardcoded to 1e-5 for now
  595. GGML_API struct ggml_tensor * ggml_norm(
  596. struct ggml_context * ctx,
  597. struct ggml_tensor * a);
  598. GGML_API struct ggml_tensor * ggml_rms_norm(
  599. struct ggml_context * ctx,
  600. struct ggml_tensor * a);
  601. // a - x
  602. // b - dy
  603. GGML_API struct ggml_tensor * ggml_rms_norm_back(
  604. struct ggml_context * ctx,
  605. struct ggml_tensor * a,
  606. struct ggml_tensor * b);
  607. // A: n columns, m rows
  608. // B: n columns, p rows (i.e. we transpose it internally)
  609. // result is m columns, p rows
  610. GGML_API struct ggml_tensor * ggml_mul_mat(
  611. struct ggml_context * ctx,
  612. struct ggml_tensor * a,
  613. struct ggml_tensor * b);
  614. // A: m columns, n rows,
  615. // B: p columns, n rows,
  616. // result is m columns, p rows
  617. GGML_API struct ggml_tensor * ggml_out_prod(
  618. struct ggml_context * ctx,
  619. struct ggml_tensor * a,
  620. struct ggml_tensor * b);
  621. //
  622. // operations on tensors without backpropagation
  623. //
  624. GGML_API struct ggml_tensor * ggml_scale(
  625. struct ggml_context * ctx,
  626. struct ggml_tensor * a,
  627. struct ggml_tensor * b);
  628. // in-place, returns view(a)
  629. GGML_API struct ggml_tensor * ggml_scale_inplace(
  630. struct ggml_context * ctx,
  631. struct ggml_tensor * a,
  632. struct ggml_tensor * b);
  633. // b -> view(a,offset,nb1,nb2,3), return modified a
  634. GGML_API struct ggml_tensor * ggml_set(
  635. struct ggml_context * ctx,
  636. struct ggml_tensor * a,
  637. struct ggml_tensor * b,
  638. size_t nb1,
  639. size_t nb2,
  640. size_t nb3,
  641. size_t offset);
  642. // b -> view(a,offset,nb1,nb2,3), return view(a)
  643. GGML_API struct ggml_tensor * ggml_set_inplace(
  644. struct ggml_context * ctx,
  645. struct ggml_tensor * a,
  646. struct ggml_tensor * b,
  647. size_t nb1,
  648. size_t nb2,
  649. size_t nb3,
  650. size_t offset);
  651. GGML_API struct ggml_tensor * ggml_set_1d(
  652. struct ggml_context * ctx,
  653. struct ggml_tensor * a,
  654. struct ggml_tensor * b,
  655. size_t offset);
  656. GGML_API struct ggml_tensor * ggml_set_1d_inplace(
  657. struct ggml_context * ctx,
  658. struct ggml_tensor * a,
  659. struct ggml_tensor * b,
  660. size_t offset);
  661. // b -> view(a,offset,nb1,nb2,3), return modified a
  662. GGML_API struct ggml_tensor * ggml_set_2d(
  663. struct ggml_context * ctx,
  664. struct ggml_tensor * a,
  665. struct ggml_tensor * b,
  666. size_t nb1,
  667. size_t offset);
  668. // b -> view(a,offset,nb1,nb2,3), return view(a)
  669. GGML_API struct ggml_tensor * ggml_set_2d_inplace(
  670. struct ggml_context * ctx,
  671. struct ggml_tensor * a,
  672. struct ggml_tensor * b,
  673. size_t nb1,
  674. size_t offset);
  675. // a -> b, return view(b)
  676. GGML_API struct ggml_tensor * ggml_cpy(
  677. struct ggml_context * ctx,
  678. struct ggml_tensor * a,
  679. struct ggml_tensor * b);
  680. // make contiguous
  681. GGML_API struct ggml_tensor * ggml_cont(
  682. struct ggml_context * ctx,
  683. struct ggml_tensor * a);
  684. // return view(a), b specifies the new shape
  685. // TODO: when we start computing gradient, make a copy instead of view
  686. GGML_API struct ggml_tensor * ggml_reshape(
  687. struct ggml_context * ctx,
  688. struct ggml_tensor * a,
  689. struct ggml_tensor * b);
  690. // return view(a)
  691. // TODO: when we start computing gradient, make a copy instead of view
  692. GGML_API struct ggml_tensor * ggml_reshape_1d(
  693. struct ggml_context * ctx,
  694. struct ggml_tensor * a,
  695. int64_t ne0);
  696. GGML_API struct ggml_tensor * ggml_reshape_2d(
  697. struct ggml_context * ctx,
  698. struct ggml_tensor * a,
  699. int64_t ne0,
  700. int64_t ne1);
  701. // return view(a)
  702. // TODO: when we start computing gradient, make a copy instead of view
  703. GGML_API struct ggml_tensor * ggml_reshape_3d(
  704. struct ggml_context * ctx,
  705. struct ggml_tensor * a,
  706. int64_t ne0,
  707. int64_t ne1,
  708. int64_t ne2);
  709. GGML_API struct ggml_tensor * ggml_reshape_4d(
  710. struct ggml_context * ctx,
  711. struct ggml_tensor * a,
  712. int64_t ne0,
  713. int64_t ne1,
  714. int64_t ne2,
  715. int64_t ne3);
  716. // offset in bytes
  717. GGML_API struct ggml_tensor * ggml_view_1d(
  718. struct ggml_context * ctx,
  719. struct ggml_tensor * a,
  720. int64_t ne0,
  721. size_t offset);
  722. GGML_API struct ggml_tensor * ggml_view_2d(
  723. struct ggml_context * ctx,
  724. struct ggml_tensor * a,
  725. int64_t ne0,
  726. int64_t ne1,
  727. size_t nb1, // row stride in bytes
  728. size_t offset);
  729. GGML_API struct ggml_tensor * ggml_view_3d(
  730. struct ggml_context * ctx,
  731. struct ggml_tensor * a,
  732. int64_t ne0,
  733. int64_t ne1,
  734. int64_t ne2,
  735. size_t nb1, // row stride in bytes
  736. size_t nb2, // slice stride in bytes
  737. size_t offset);
  738. GGML_API struct ggml_tensor * ggml_view_4d(
  739. struct ggml_context * ctx,
  740. struct ggml_tensor * a,
  741. int64_t ne0,
  742. int64_t ne1,
  743. int64_t ne2,
  744. int64_t ne3,
  745. size_t nb1, // row stride in bytes
  746. size_t nb2, // slice stride in bytes
  747. size_t nb3,
  748. size_t offset);
  749. GGML_API struct ggml_tensor * ggml_permute(
  750. struct ggml_context * ctx,
  751. struct ggml_tensor * a,
  752. int axis0,
  753. int axis1,
  754. int axis2,
  755. int axis3);
  756. // alias for ggml_permute(ctx, a, 1, 0, 2, 3)
  757. GGML_API struct ggml_tensor * ggml_transpose(
  758. struct ggml_context * ctx,
  759. struct ggml_tensor * a);
  760. GGML_API struct ggml_tensor * ggml_get_rows(
  761. struct ggml_context * ctx,
  762. struct ggml_tensor * a,
  763. struct ggml_tensor * b);
  764. GGML_API struct ggml_tensor * ggml_get_rows_back(
  765. struct ggml_context * ctx,
  766. struct ggml_tensor * a,
  767. struct ggml_tensor * b,
  768. struct ggml_tensor * c);
  769. GGML_API struct ggml_tensor * ggml_diag(
  770. struct ggml_context * ctx,
  771. struct ggml_tensor * a);
  772. // set elements above the diagonal to -INF
  773. GGML_API struct ggml_tensor * ggml_diag_mask_inf(
  774. struct ggml_context * ctx,
  775. struct ggml_tensor * a,
  776. int n_past);
  777. // in-place, returns view(a)
  778. GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace(
  779. struct ggml_context * ctx,
  780. struct ggml_tensor * a,
  781. int n_past);
  782. // set elements above the diagonal to 0
  783. GGML_API struct ggml_tensor * ggml_diag_mask_zero(
  784. struct ggml_context * ctx,
  785. struct ggml_tensor * a,
  786. int n_past);
  787. // in-place, returns view(a)
  788. GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace(
  789. struct ggml_context * ctx,
  790. struct ggml_tensor * a,
  791. int n_past);
  792. GGML_API struct ggml_tensor * ggml_soft_max(
  793. struct ggml_context * ctx,
  794. struct ggml_tensor * a);
  795. // in-place, returns view(a)
  796. GGML_API struct ggml_tensor * ggml_soft_max_inplace(
  797. struct ggml_context * ctx,
  798. struct ggml_tensor * a);
  799. GGML_API struct ggml_tensor * ggml_soft_max_back(
  800. struct ggml_context * ctx,
  801. struct ggml_tensor * a,
  802. struct ggml_tensor * b);
  803. // in-place, returns view(a)
  804. GGML_API struct ggml_tensor * ggml_soft_max_back_inplace(
  805. struct ggml_context * ctx,
  806. struct ggml_tensor * a,
  807. struct ggml_tensor * b);
  808. // rotary position embedding
  809. // if mode & 1 == 1, skip n_past elements
  810. // if mode & 2 == 1, GPT-NeoX style
  811. // TODO: avoid creating a new tensor every time
  812. GGML_API struct ggml_tensor * ggml_rope(
  813. struct ggml_context * ctx,
  814. struct ggml_tensor * a,
  815. int n_past,
  816. int n_dims,
  817. int mode);
  818. // in-place, returns view(a)
  819. GGML_API struct ggml_tensor * ggml_rope_inplace(
  820. struct ggml_context * ctx,
  821. struct ggml_tensor * a,
  822. int n_past,
  823. int n_dims,
  824. int mode);
  825. // rotary position embedding backward, i.e compute dx from dy
  826. // a - dy
  827. GGML_API struct ggml_tensor * ggml_rope_back(
  828. struct ggml_context * ctx,
  829. struct ggml_tensor * a,
  830. int n_past,
  831. int n_dims,
  832. int mode);
  833. // alibi position embedding
  834. // in-place, returns view(a)
  835. struct ggml_tensor * ggml_alibi(
  836. struct ggml_context * ctx,
  837. struct ggml_tensor * a,
  838. int n_past,
  839. int n_head,
  840. float bias_max);
  841. // clamp
  842. // in-place, returns view(a)
  843. struct ggml_tensor * ggml_clamp(
  844. struct ggml_context * ctx,
  845. struct ggml_tensor * a,
  846. float min,
  847. float max);
  848. // padding = 1
  849. // TODO: we don't support extra parameters for now
  850. // that's why we are hard-coding the stride, padding, and dilation
  851. // not great ..
  852. GGML_API struct ggml_tensor * ggml_conv_1d_1s(
  853. struct ggml_context * ctx,
  854. struct ggml_tensor * a,
  855. struct ggml_tensor * b);
  856. GGML_API struct ggml_tensor * ggml_conv_1d_2s(
  857. struct ggml_context * ctx,
  858. struct ggml_tensor * a,
  859. struct ggml_tensor * b);
  860. GGML_API struct ggml_tensor * ggml_flash_attn(
  861. struct ggml_context * ctx,
  862. struct ggml_tensor * q,
  863. struct ggml_tensor * k,
  864. struct ggml_tensor * v,
  865. bool masked);
  866. GGML_API struct ggml_tensor * ggml_flash_attn_back(
  867. struct ggml_context * ctx,
  868. struct ggml_tensor * q,
  869. struct ggml_tensor * k,
  870. struct ggml_tensor * v,
  871. struct ggml_tensor * d,
  872. bool masked);
  873. GGML_API struct ggml_tensor * ggml_flash_ff(
  874. struct ggml_context * ctx,
  875. struct ggml_tensor * a,
  876. struct ggml_tensor * b0,
  877. struct ggml_tensor * b1,
  878. struct ggml_tensor * c0,
  879. struct ggml_tensor * c1);
  880. // Mapping operations
  881. typedef void (*ggml_unary_op_f32_t)(const int, float *, const float *);
  882. typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);
  883. GGML_API struct ggml_tensor * ggml_map_unary_f32(
  884. struct ggml_context * ctx,
  885. struct ggml_tensor * a,
  886. ggml_unary_op_f32_t fun);
  887. GGML_API struct ggml_tensor * ggml_map_binary_f32(
  888. struct ggml_context * ctx,
  889. struct ggml_tensor * a,
  890. struct ggml_tensor * b,
  891. ggml_binary_op_f32_t fun);
  892. // loss function
  893. GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
  894. struct ggml_context * ctx,
  895. struct ggml_tensor * a,
  896. struct ggml_tensor * b);
  897. GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
  898. struct ggml_context * ctx,
  899. struct ggml_tensor * a,
  900. struct ggml_tensor * b,
  901. struct ggml_tensor * c);
  902. //
  903. // automatic differentiation
  904. //
  905. GGML_API void ggml_set_param(
  906. struct ggml_context * ctx,
  907. struct ggml_tensor * tensor);
  908. GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
  909. GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
  910. GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
  911. GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
  912. GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
  913. GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name);
  914. GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname);
  915. GGML_API struct ggml_cgraph ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval);
  916. // print info and performance information for the graph
  917. GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
  918. // dump the graph into a file using the dot format
  919. GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
  920. //
  921. // optimization
  922. //
  923. // optimization methods
  924. enum ggml_opt_type {
  925. GGML_OPT_ADAM,
  926. GGML_OPT_LBFGS,
  927. };
  928. // linesearch methods
  929. enum ggml_linesearch {
  930. GGML_LINESEARCH_DEFAULT = 1,
  931. GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0,
  932. GGML_LINESEARCH_BACKTRACKING_WOLFE = 1,
  933. GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
  934. };
  935. // optimization return values
  936. enum ggml_opt_result {
  937. GGML_OPT_OK = 0,
  938. GGML_OPT_DID_NOT_CONVERGE,
  939. GGML_OPT_NO_CONTEXT,
  940. GGML_OPT_INVALID_WOLFE,
  941. GGML_OPT_FAIL,
  942. GGML_LINESEARCH_FAIL = -128,
  943. GGML_LINESEARCH_MINIMUM_STEP,
  944. GGML_LINESEARCH_MAXIMUM_STEP,
  945. GGML_LINESEARCH_MAXIMUM_ITERATIONS,
  946. GGML_LINESEARCH_INVALID_PARAMETERS,
  947. };
  948. // optimization parameters
  949. //
  950. // see ggml.c (ggml_opt_default_params) for default values
  951. //
  952. struct ggml_opt_params {
  953. enum ggml_opt_type type;
  954. int n_threads;
  955. // delta-based convergence test
  956. //
  957. // if past == 0 - disabled
  958. // if past > 0:
  959. // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
  960. //
  961. int past;
  962. float delta;
  963. // maximum number of iterations without improvement
  964. //
  965. // if 0 - disabled
  966. // if > 0:
  967. // assume convergence if no cost improvement in this number of iterations
  968. //
  969. int max_no_improvement;
  970. bool print_forward_graph;
  971. bool print_backward_graph;
  972. // ADAM parameters
  973. struct {
  974. int n_iter;
  975. float sched; // schedule multiplier (fixed, decay or warmup)
  976. float decay; // weight decay for AdamW, use 0.0f to disable
  977. float alpha; // learning rate
  978. float beta1;
  979. float beta2;
  980. float eps; // epsilon for numerical stability
  981. float eps_f; // epsilon for convergence test
  982. float eps_g; // epsilon for convergence test
  983. } adam;
  984. // LBFGS parameters
  985. struct {
  986. int m; // number of corrections to approximate the inv. Hessian
  987. int n_iter;
  988. int max_linesearch;
  989. float eps; // convergence tolerance
  990. float ftol; // line search tolerance
  991. float wolfe;
  992. float min_step;
  993. float max_step;
  994. enum ggml_linesearch linesearch;
  995. } lbfgs;
  996. };
  997. struct ggml_opt_context {
  998. struct ggml_context * ctx;
  999. struct ggml_opt_params params;
  1000. int iter;
  1001. int64_t nx; // number of parameter elements
  1002. bool just_initialized;
  1003. struct {
  1004. struct ggml_tensor * x; // view of the parameters
  1005. struct ggml_tensor * g1; // gradient
  1006. struct ggml_tensor * g2; // gradient squared
  1007. struct ggml_tensor * m; // first moment
  1008. struct ggml_tensor * v; // second moment
  1009. struct ggml_tensor * mh; // first moment hat
  1010. struct ggml_tensor * vh; // second moment hat
  1011. struct ggml_tensor * pf; // past function values
  1012. float fx_best;
  1013. float fx_prev;
  1014. int n_no_improvement;
  1015. } adam;
  1016. struct {
  1017. struct ggml_tensor * x; // current parameters
  1018. struct ggml_tensor * xp; // previous parameters
  1019. struct ggml_tensor * g; // current gradient
  1020. struct ggml_tensor * gp; // previous gradient
  1021. struct ggml_tensor * d; // search direction
  1022. struct ggml_tensor * pf; // past function values
  1023. struct ggml_tensor * lmal; // the L-BFGS memory alpha
  1024. struct ggml_tensor * lmys; // the L-BFGS memory ys
  1025. struct ggml_tensor * lms; // the L-BFGS memory s
  1026. struct ggml_tensor * lmy; // the L-BFGS memory y
  1027. float fx_best;
  1028. float step;
  1029. int j;
  1030. int k;
  1031. int end;
  1032. int n_no_improvement;
  1033. } lbfgs;
  1034. };
  1035. GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);
  1036. // optimize the function defined by the tensor f
  1037. GGML_API enum ggml_opt_result ggml_opt(
  1038. struct ggml_context * ctx,
  1039. struct ggml_opt_params params,
  1040. struct ggml_tensor * f);
  1041. // initialize optimizer context
  1042. GGML_API void ggml_opt_init(
  1043. struct ggml_context * ctx,
  1044. struct ggml_opt_context * opt,
  1045. struct ggml_opt_params params,
  1046. int64_t nx);
  1047. // continue optimizing the function defined by the tensor f
  1048. GGML_API enum ggml_opt_result ggml_opt_resume(
  1049. struct ggml_context * ctx,
  1050. struct ggml_opt_context * opt,
  1051. struct ggml_tensor * f);
  1052. // continue optimizing the function defined by the tensor f
  1053. GGML_API enum ggml_opt_result ggml_opt_resume_g(
  1054. struct ggml_context * ctx,
  1055. struct ggml_opt_context * opt,
  1056. struct ggml_tensor * f,
  1057. struct ggml_cgraph * gf,
  1058. struct ggml_cgraph * gb);
  1059. //
  1060. // quantization
  1061. //
  1062. GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
  1063. GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
  1064. GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist);
  1065. GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist);
  1066. GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist);
  1067. GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
  1068. //
  1069. // system info
  1070. //
  1071. GGML_API int ggml_cpu_has_avx (void);
  1072. GGML_API int ggml_cpu_has_avx2 (void);
  1073. GGML_API int ggml_cpu_has_avx512 (void);
  1074. GGML_API int ggml_cpu_has_avx512_vbmi(void);
  1075. GGML_API int ggml_cpu_has_avx512_vnni(void);
  1076. GGML_API int ggml_cpu_has_fma (void);
  1077. GGML_API int ggml_cpu_has_neon (void);
  1078. GGML_API int ggml_cpu_has_arm_fma (void);
  1079. GGML_API int ggml_cpu_has_f16c (void);
  1080. GGML_API int ggml_cpu_has_fp16_va (void);
  1081. GGML_API int ggml_cpu_has_wasm_simd (void);
  1082. GGML_API int ggml_cpu_has_blas (void);
  1083. GGML_API int ggml_cpu_has_cublas (void);
  1084. GGML_API int ggml_cpu_has_clblast (void);
  1085. GGML_API int ggml_cpu_has_gpublas (void);
  1086. GGML_API int ggml_cpu_has_sse3 (void);
  1087. GGML_API int ggml_cpu_has_vsx (void);
  1088. //
  1089. // Internal types and functions exposed for tests and benchmarks
  1090. //
  1091. #ifdef __cplusplus
  1092. // restrict not standard in C++
  1093. #define GGML_RESTRICT
  1094. #else
  1095. #define GGML_RESTRICT restrict
  1096. #endif
  1097. typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
  1098. typedef void (*quantize_row_q_t) (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
  1099. typedef void (*vec_dot_q_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);
  1100. typedef struct {
  1101. dequantize_row_q_t dequantize_row_q;
  1102. quantize_row_q_t quantize_row_q;
  1103. quantize_row_q_t quantize_row_q_reference;
  1104. quantize_row_q_t quantize_row_q_dot;
  1105. vec_dot_q_t vec_dot_q;
  1106. enum ggml_type vec_dot_type;
  1107. } quantize_fns_t;
  1108. quantize_fns_t ggml_internal_get_quantize_fn(size_t i);
  1109. #ifdef __cplusplus
  1110. }
  1111. #endif