#pragma once

//
// GGML Tensor Library
//
// This documentation is still a work in progress.
// If you would like specific topics to be covered, feel free to drop a comment:
//
//   https://github.com/ggerganov/whisper.cpp/issues/40
//
// ## Overview
//
// This library implements:
//
//  - a set of tensor operations
//  - automatic differentiation
//  - basic optimization algorithms
//
// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes,
// but is not limited to, the following:
//
//  - linear regression
//  - support vector machines
//  - neural networks
//
// The library allows the user to define a certain function using the available tensor operations. This function
// definition is represented internally via a computation graph. Each tensor operation in the function definition
// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the
// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized
// using one of the available optimization algorithms.
//
// For example, here we define the function: f(x) = a*x^2 + b
//
//   {
//       struct ggml_init_params params = {
//           .mem_size   = 16*1024*1024,
//           .mem_buffer = NULL,
//       };
//
//       // memory allocation happens here
//       struct ggml_context * ctx = ggml_init(params);
//
//       struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
//
//       ggml_set_param(ctx, x); // x is an input variable
//
//       struct ggml_tensor * a  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
//       struct ggml_tensor * b  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
//       struct ggml_tensor * x2 = ggml_mul(ctx, x, x);
//       struct ggml_tensor * f  = ggml_add(ctx, ggml_mul(ctx, a, x2), b);
//
//       ...
//   }
//
// Notice that the function definition above does not involve any actual computation. The computation is performed only
// when the user explicitly requests it. For example, to compute the function's value at x = 2.0:
//
//   {
//       ...
//
//       struct ggml_cgraph gf = ggml_build_forward(f);
//
//       // set the input variable and parameter values
//       ggml_set_f32(x, 2.0f);
//       ggml_set_f32(a, 3.0f);
//       ggml_set_f32(b, 4.0f);
//
//       ggml_graph_compute(ctx, &gf);
//
//       printf("f = %f\n", ggml_get_f32_1d(f, 0));
//
//       ...
//   }
//
// The actual computation is performed in the ggml_graph_compute() function.
//
// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the
// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know
// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory
// buffer and, after defining the computation graph, call the ggml_used_mem() function to find out how much memory
// was actually needed.
//
// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic
// differentiation and optimization algorithms.
//
// The described approach allows the user to define the function graph once and then compute its forward or backward
// passes multiple times. All computations will use the same memory buffer allocated in the ggml_init() function.
// This way the user can avoid the memory allocation overhead at runtime.
//
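// For example, the gradient df/dx of the function defined earlier can be obtained from the backward graph.
// A minimal sketch, continuing the example above (the output adjoint f->grad is seeded with 1.0 before the
// backward pass):
//
//   {
//       ...
//
//       struct ggml_cgraph gf = ggml_build_forward (f);
//       struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, false);
//
//       ggml_graph_compute(ctx, &gf);
//
//       ggml_graph_reset  (&gf);
//       ggml_set_f32      (f->grad, 1.0f);
//       ggml_graph_compute(ctx, &gb);
//
//       printf("df/dx = %f\n", ggml_get_f32_1d(x->grad, 0));
//
//       ...
//   }
//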
// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class
// citizens, but in theory the library can be extended to support FP8 and integer data types.
//
// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary
// and binary operations. Most of the available operations fall into one of these two categories. With time, it became
// clear that the library needs to support more complex operations. The way to support these operations is not clear
// yet, but a few examples are demonstrated in the following operations:
//
//  - ggml_permute()
//  - ggml_conv_1d_1s()
//  - ggml_conv_1d_2s()
//
// For each tensor operator, the library implements a forward and backward computation function. The forward function
// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the
// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a
// calculus class, or watch the following video:
//
//   What is Automatic Differentiation?
//   https://www.youtube.com/watch?v=wG_nF1awSSY
//
//
// ## Tensor data (struct ggml_tensor)
//
// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of
// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains
// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example:
//
//   {
//       struct ggml_tensor * c = ggml_add(ctx, a, b);
//
//       assert(c->src[0] == a);
//       assert(c->src[1] == b);
//   }
//
// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the
// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This makes it
// possible to store tensors that are not contiguous in memory, which is useful for operations such as transposition
// and permutation. All tensor operations have to take the stride into account and not assume that the tensor is
// contiguous in memory.
//
// The data of the tensor is accessed via the "data" pointer. For example:
//
//   {
//       struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, 3);
//
//       // a[1, 2] = 1.0f;
//       *(float *) ((char *) a->data + 2*a->nb[1] + 1*a->nb[0]) = 1.0f;
//
//       // a[0, 2] = 2.0f;
//       *(float *) ((char *) a->data + 2*a->nb[1] + 0*a->nb[0]) = 2.0f;
//
//       ...
//   }
//
// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d(), that can be used.
//
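// For example, the same write using flat indexing (a sketch; for a contiguous row-major tensor the flat index
// of element (i0, i1) is i1*ne[0] + i0):
//
//   {
//       ggml_set_f32_1d(a, 2*2 + 1, 1.0f);           // a[1, 2]
//
//       const float v = ggml_get_f32_1d(a, 2*2 + 1); // read it back
//   }
//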
// ## The matrix multiplication operator (ggml_mul_mat)
//
// TODO
//
//
// ## Multi-threading
//
// TODO
//
//
// ## Overview of ggml.c
//
// TODO
//
//
// ## SIMD optimizations
//
// TODO
//
//
// ## Debugging ggml
//
// TODO
//
//

#ifdef  __cplusplus
extern "C" {
#endif

#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>

#define GGML_MAX_DIMS          4
#define GGML_MAX_NODES         4096
#define GGML_MAX_PARAMS        16
#define GGML_MAX_CONTEXTS      64
#define GGML_MAX_OPT           4
#define GGML_DEFAULT_N_THREADS 4

#ifdef __ARM_NEON
// we use the built-in 16-bit float type
typedef __fp16 ggml_fp16_t;
#else
typedef uint16_t ggml_fp16_t;
#endif

// convert FP16 <-> FP32
float       ggml_fp16_to_fp32(ggml_fp16_t x);
ggml_fp16_t ggml_fp32_to_fp16(float x);

struct ggml_object;
struct ggml_context;

enum ggml_type {
    // explicitly numbered values are used in llama.cpp files
    GGML_TYPE_F32  = 0,
    GGML_TYPE_F16  = 1,
    GGML_TYPE_Q4_0 = 2,
    GGML_TYPE_Q4_1 = 3,
    GGML_TYPE_Q4_2 = 4,
    GGML_TYPE_Q4_3 = 5,
    GGML_TYPE_Q8_0 = 6,
    GGML_TYPE_I8,
    GGML_TYPE_I16,
    GGML_TYPE_I32,
    GGML_TYPE_COUNT,
};

// available tensor operations:
enum ggml_op {
    GGML_OP_NONE = 0,

    GGML_OP_DUP,
    GGML_OP_ADD,
    GGML_OP_SUB,
    GGML_OP_MUL,
    GGML_OP_DIV,
    GGML_OP_SQR,
    GGML_OP_SQRT,
    GGML_OP_SUM,
    GGML_OP_MEAN,
    GGML_OP_REPEAT,
    GGML_OP_ABS,
    GGML_OP_SGN,
    GGML_OP_NEG,
    GGML_OP_STEP,
    GGML_OP_RELU,
    GGML_OP_GELU,
    GGML_OP_SILU,
    GGML_OP_NORM, // normalize
    GGML_OP_RMS_NORM,

    GGML_OP_MUL_MAT,

    GGML_OP_SCALE,
    GGML_OP_CPY,
    GGML_OP_CONT,
    GGML_OP_RESHAPE,
    GGML_OP_VIEW,
    GGML_OP_PERMUTE,
    GGML_OP_TRANSPOSE,
    GGML_OP_GET_ROWS,
    GGML_OP_DIAG_MASK_INF,
    GGML_OP_SOFT_MAX,
    GGML_OP_ROPE,
    GGML_OP_CONV_1D_1S,
    GGML_OP_CONV_1D_2S,

    GGML_OP_FLASH_ATTN,
    GGML_OP_FLASH_FF,

    GGML_OP_MAP_UNARY,
    GGML_OP_MAP_BINARY,

    GGML_OP_COUNT,
};

// ggml object
struct ggml_object {
    size_t offs;
    size_t size;

    struct ggml_object * next;

    char padding[8];
};

static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object);

// n-dimensional tensor
struct ggml_tensor {
    enum ggml_type type;

    int     n_dims;
    int64_t ne[GGML_MAX_DIMS]; // number of elements
    size_t  nb[GGML_MAX_DIMS]; // stride in bytes:
                               // nb[0] = sizeof(type)
                               // nb[1] = nb[0]   * ne[0] + padding
                               // nb[i] = nb[i-1] * ne[i-1]

    // compute data
    enum ggml_op op;

    bool is_param;

    struct ggml_tensor * grad;
    struct ggml_tensor * src0;
    struct ggml_tensor * src1;
    struct ggml_tensor * opt[GGML_MAX_OPT];

    // thread scheduling
    int n_tasks;

    // performance
    int     perf_runs;
    int64_t perf_cycles;
    int64_t perf_time_us;

    void * data;
    char   padding[8];
};
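
// Given the strides above, the byte offset of element (i0, i1, i2, i3) is i0*nb[0] + i1*nb[1] + i2*nb[2] + i3*nb[3].
// A minimal sketch of an accessor (illustration only - not part of the library API):
//
//   static inline float tensor_f32_at(const struct ggml_tensor * t, int i0, int i1, int i2, int i3) {
//       // assumes t holds GGML_TYPE_F32 data
//       return *(float *) ((char *) t->data + i0*t->nb[0] + i1*t->nb[1] + i2*t->nb[2] + i3*t->nb[3]);
//   }
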
// computation graph
struct ggml_cgraph {
    int n_nodes;
    int n_leafs;
    int n_threads;

    size_t work_size;
    struct ggml_tensor * work;

    struct ggml_tensor * nodes[GGML_MAX_NODES];
    struct ggml_tensor * grads[GGML_MAX_NODES];
    struct ggml_tensor * leafs[GGML_MAX_NODES];

    // performance
    int     perf_runs;
    int64_t perf_cycles;
    int64_t perf_time_us;
};

// scratch buffer
struct ggml_scratch {
    size_t offs;
    size_t size;
    void * data;
};

struct ggml_init_params {
    // memory pool
    size_t mem_size;   // bytes
    void * mem_buffer; // if NULL, memory will be allocated internally
    bool   no_alloc;   // don't allocate memory for the tensor data
};

void    ggml_time_init(void); // call this once at the beginning of the program
int64_t ggml_time_ms(void);
int64_t ggml_time_us(void);
int64_t ggml_cycles(void);
int64_t ggml_cycles_per_ms(void);

void ggml_print_object (const struct ggml_object * obj);
void ggml_print_objects(const struct ggml_context * ctx);

int64_t ggml_nelements(const struct ggml_tensor * tensor);
size_t  ggml_nbytes   (const struct ggml_tensor * tensor);

int    ggml_blck_size (enum ggml_type type);
size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
float  ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float

const char * ggml_type_name(enum ggml_type type);

size_t ggml_element_size(const struct ggml_tensor * tensor);

bool ggml_is_quantized(enum ggml_type type);

struct ggml_context * ggml_init(struct ggml_init_params params);
void                  ggml_free(struct ggml_context * ctx);

size_t ggml_used_mem(const struct ggml_context * ctx);

size_t ggml_set_scratch(struct ggml_context * ctx, struct ggml_scratch scratch);

struct ggml_tensor * ggml_new_tensor(
        struct ggml_context * ctx,
        enum   ggml_type      type,
        int                   n_dims,
        const int64_t       * ne);

struct ggml_tensor * ggml_new_tensor_1d(
        struct ggml_context * ctx,
        enum   ggml_type      type,
        int64_t               ne0);

struct ggml_tensor * ggml_new_tensor_2d(
        struct ggml_context * ctx,
        enum   ggml_type      type,
        int64_t               ne0,
        int64_t               ne1);

struct ggml_tensor * ggml_new_tensor_3d(
        struct ggml_context * ctx,
        enum   ggml_type      type,
        int64_t               ne0,
        int64_t               ne1,
        int64_t               ne2);

struct ggml_tensor * ggml_new_tensor_4d(
        struct ggml_context * ctx,
        enum   ggml_type      type,
        int64_t               ne0,
        int64_t               ne1,
        int64_t               ne2,
        int64_t               ne3);

struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);

struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, const struct ggml_tensor * src);

struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);

int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
void    ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);

float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
void  ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);

void *  ggml_get_data    (const struct ggml_tensor * tensor);
float * ggml_get_data_f32(const struct ggml_tensor * tensor);

//
// operations on tensors with backpropagation
//

struct ggml_tensor * ggml_dup(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

struct ggml_tensor * ggml_add(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

struct ggml_tensor * ggml_add_inplace(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

struct ggml_tensor * ggml_sub(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

struct ggml_tensor * ggml_mul(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

struct ggml_tensor * ggml_div(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

struct ggml_tensor * ggml_sqr(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

struct ggml_tensor * ggml_sqrt(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

// return scalar
// TODO: compute sum along rows
struct ggml_tensor * ggml_sum(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

// mean along rows
struct ggml_tensor * ggml_mean(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

// if a is the same shape as b, and a is not a parameter, return a
// otherwise, return a new tensor: repeat(a) to fit in b
struct ggml_tensor * ggml_repeat(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

struct ggml_tensor * ggml_abs(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

struct ggml_tensor * ggml_sgn(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

struct ggml_tensor * ggml_neg(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

struct ggml_tensor * ggml_step(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

struct ggml_tensor * ggml_relu(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

// TODO: double-check this computation is correct
struct ggml_tensor * ggml_gelu(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

struct ggml_tensor * ggml_silu(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

// normalize along rows
// TODO: eps is hardcoded to 1e-5 for now
struct ggml_tensor * ggml_norm(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

struct ggml_tensor * ggml_rms_norm(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

// A: m rows, n columns
// B: p rows, n columns (i.e. we transpose it internally)
// result is m columns, p rows
struct ggml_tensor * ggml_mul_mat(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);
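
// For example, the shape bookkeeping only (a sketch; n, m, p stand for placeholder dimensions, and ne[0]
// holds the row length):
//
//   {
//       struct ggml_tensor * A = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n, m); // m rows, n columns
//       struct ggml_tensor * B = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n, p); // p rows, n columns
//       struct ggml_tensor * C = ggml_mul_mat(ctx, A, B); // C->ne[0] == m, C->ne[1] == p
//   }
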
//
// operations on tensors without backpropagation
//

// in-place, returns view(a)
struct ggml_tensor * ggml_scale(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

// a -> b, return view(b)
struct ggml_tensor * ggml_cpy(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

// make contiguous
struct ggml_tensor * ggml_cont(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

// return view(a), b specifies the new shape
// TODO: when we start computing gradient, make a copy instead of view
struct ggml_tensor * ggml_reshape(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

// return view(a)
// TODO: when we start computing gradient, make a copy instead of view
struct ggml_tensor * ggml_reshape_2d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int64_t               ne0,
        int64_t               ne1);

// return view(a)
// TODO: when we start computing gradient, make a copy instead of view
struct ggml_tensor * ggml_reshape_3d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int64_t               ne0,
        int64_t               ne1,
        int64_t               ne2);

// offset in bytes
struct ggml_tensor * ggml_view_1d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int64_t               ne0,
        size_t                offset);

struct ggml_tensor * ggml_view_2d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int64_t               ne0,
        int64_t               ne1,
        size_t                nb1, // row stride in bytes
        size_t                offset);

struct ggml_tensor * ggml_view_3d(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int64_t               ne0,
        int64_t               ne1,
        int64_t               ne2,
        size_t                nb1, // row   stride in bytes
        size_t                nb2, // slice stride in bytes
        size_t                offset);
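
// For example, a 1-d view of the second row of a 2-d tensor (a sketch; "a" is assumed to be a 2-d F32 tensor):
//
//   {
//       // ne0 = a->ne[0] elements, starting a->nb[1] bytes into the data - i.e. row 1
//       struct ggml_tensor * row1 = ggml_view_1d(ctx, a, a->ne[0], a->nb[1]);
//   }
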
struct ggml_tensor * ggml_permute(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   axis0,
        int                   axis1,
        int                   axis2,
        int                   axis3);

// alias for ggml_permute(ctx, a, 1, 0, 2, 3)
struct ggml_tensor * ggml_transpose(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

struct ggml_tensor * ggml_get_rows(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

// set elements above the diagonal to -INF
// in-place, returns view(a)
struct ggml_tensor * ggml_diag_mask_inf(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   n_past);

// in-place, returns view(a)
struct ggml_tensor * ggml_soft_max(
        struct ggml_context * ctx,
        struct ggml_tensor  * a);

// rotary position embedding
// in-place, returns view(a)
// if (mode & 1) != 0, skip n_past elements
// if (mode & 2) != 0, GPT-NeoX style
// TODO: avoid creating a new tensor every time
struct ggml_tensor * ggml_rope(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   n_past,
        int                   n_dims,
        int                   mode);

// padding = 1
// TODO: we don't support extra parameters for now
//       that's why we are hard-coding the stride, padding, and dilation
//       not great ..
struct ggml_tensor * ggml_conv_1d_1s(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

struct ggml_tensor * ggml_conv_1d_2s(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b);

struct ggml_tensor * ggml_flash_attn(
        struct ggml_context * ctx,
        struct ggml_tensor  * q,
        struct ggml_tensor  * k,
        struct ggml_tensor  * v,
        bool                  masked);

struct ggml_tensor * ggml_flash_ff(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b0,
        struct ggml_tensor  * b1,
        struct ggml_tensor  * c0,
        struct ggml_tensor  * c1);

// Mapping operations
typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *);

struct ggml_tensor * ggml_map_unary_f32(
        struct ggml_context        * ctx,
        struct ggml_tensor         * a,
        const  ggml_unary_op_f32_t   fun);

struct ggml_tensor * ggml_map_binary_f32(
        struct ggml_context         * ctx,
        struct ggml_tensor          * a,
        struct ggml_tensor          * b,
        const  ggml_binary_op_f32_t   fun);
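
// For example, applying a custom element-wise function (a sketch; "scale2_f32" is a hypothetical user
// function matching ggml_unary_op_f32_t):
//
//   static void scale2_f32(const int n, float * dst, const float * src) {
//       for (int i = 0; i < n; i++) {
//           dst[i] = 2.0f*src[i];
//       }
//   }
//
//   ...
//
//   struct ggml_tensor * b = ggml_map_unary_f32(ctx, a, scale2_f32);
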
//
// automatic differentiation
//

void ggml_set_param(
        struct ggml_context * ctx,
        struct ggml_tensor  * tensor);

void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);

struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);

void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph);
void ggml_graph_reset  (struct ggml_cgraph * cgraph);

// print info and performance information for the graph
void ggml_graph_print(const struct ggml_cgraph * cgraph);

// dump the graph into a file using the dot format
void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
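
// For example (a sketch; "gf" is a forward graph built as in the overview, and NULL is passed for the second
// argument when there is no separate forward graph to annotate):
//
//   ggml_graph_print   (&gf);
//   ggml_graph_dump_dot(&gf, NULL, "ggml.dot"); // render with: dot -Tpng ggml.dot -o ggml.png
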
//
// optimization
//

// optimization methods
enum ggml_opt_type {
    GGML_OPT_ADAM,
    GGML_OPT_LBFGS,
};

// linesearch methods
enum ggml_linesearch {
    GGML_LINESEARCH_DEFAULT = 1,

    GGML_LINESEARCH_BACKTRACKING_ARMIJO       = 0,
    GGML_LINESEARCH_BACKTRACKING_WOLFE        = 1,
    GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2,
};

// optimization return values
enum ggml_opt_result {
    GGML_OPT_OK = 0,
    GGML_OPT_DID_NOT_CONVERGE,
    GGML_OPT_NO_CONTEXT,
    GGML_OPT_INVALID_WOLFE,
    GGML_OPT_FAIL,

    GGML_LINESEARCH_FAIL = -128,
    GGML_LINESEARCH_MINIMUM_STEP,
    GGML_LINESEARCH_MAXIMUM_STEP,
    GGML_LINESEARCH_MAXIMUM_ITERATIONS,
    GGML_LINESEARCH_INVALID_PARAMETERS,
};

// optimization parameters
//
//   see ggml.c (ggml_opt_default_params) for default values
//
struct ggml_opt_params {
    enum ggml_opt_type type;

    int n_threads;

    // delta-based convergence test
    //
    //   if past == 0 - disabled
    //   if past > 0:
    //     stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|)
    //
    int   past;
    float delta;

    // maximum number of iterations without improvement
    //
    //   if 0 - disabled
    //   if > 0:
    //     assume convergence if no cost improvement in this number of iterations
    //
    int max_no_improvement;

    bool print_forward_graph;
    bool print_backward_graph;

    // ADAM parameters
    struct {
        int n_iter;

        float alpha; // learning rate
        float beta1;
        float beta2;
        float eps;   // epsilon for numerical stability
        float eps_f; // epsilon for convergence test
        float eps_g; // epsilon for convergence test
    } adam;

    // LBFGS parameters
    struct {
        int m; // number of corrections to approximate the inv. Hessian
        int n_iter;
        int max_linesearch;

        float eps;      // convergence tolerance
        float ftol;     // line search tolerance
        float wolfe;
        float min_step;
        float max_step;

        enum ggml_linesearch linesearch;
    } lbfgs;
};

struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type);

// optimize the function defined by the tensor f
enum ggml_opt_result ggml_opt(
        struct ggml_context * ctx,
        struct ggml_opt_params params,
        struct ggml_tensor * f);
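
// For example, minimizing the function f from the overview with the default ADAM settings (a sketch):
//
//   {
//       struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
//
//       enum ggml_opt_result res = ggml_opt(ctx, params, f);
//       if (res != GGML_OPT_OK) {
//           // handle failure
//       }
//   }
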
//
// quantization
//

size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist);
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist);
size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t * hist);
size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t * hist);

size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist);
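
// For example, quantizing a row-major F32 buffer to Q4_0 (a sketch; assumes n is a multiple of the block
// size, k is the row length, and "hist" accumulates counts of the 16 possible 4-bit values):
//
//   {
//       int64_t hist[16] = {0};
//       size_t  size = ggml_quantize_q4_0(src, dst, n, k, hist); // bytes written to dst
//   }
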
//
// system info
//

int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void);
int ggml_cpu_has_avx512_vbmi(void);
int ggml_cpu_has_avx512_vnni(void);
int ggml_cpu_has_fma(void);
int ggml_cpu_has_neon(void);
int ggml_cpu_has_arm_fma(void);
int ggml_cpu_has_f16c(void);
int ggml_cpu_has_fp16_va(void);
int ggml_cpu_has_wasm_simd(void);
int ggml_cpu_has_blas(void);
int ggml_cpu_has_cublas(void);
int ggml_cpu_has_sse3(void);
int ggml_cpu_has_vsx(void);

//
// Internal types and functions exposed for tests and benchmarks
//

#ifdef  __cplusplus
// restrict not standard in C++
#define GGML_RESTRICT
#else
#define GGML_RESTRICT restrict
#endif

typedef void (*dequantize_row_q_t)(const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k);
typedef void (*quantize_row_q_t)  (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k);
typedef void (*vec_dot_q_t)       (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y);

typedef struct {
    dequantize_row_q_t dequantize_row_q;
    quantize_row_q_t   quantize_row_q;
    quantize_row_q_t   quantize_row_q_reference;
    quantize_row_q_t   quantize_row_q_dot;
    vec_dot_q_t        vec_dot_q;
} quantize_fns_t;

quantize_fns_t ggml_internal_get_quantize_fn(size_t i);

#ifdef  __cplusplus
}
#endif