1
0

llama.h 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928
  1. #ifndef LLAMA_H
  2. #define LLAMA_H
  3. #include "ggml.h"
  4. #include "ggml-backend.h"
  // Maximum number of devices a model can be spread over; this is also the
  // length of llama_model_params.tensor_split. The value follows the GPU
  // backend the library was compiled with; builds without one get a single device.
  #if defined(GGML_USE_CUBLAS)
  #    include "ggml-cuda.h"
  #    define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
  #elif defined(GGML_USE_SYCL)
  #    include "ggml-sycl.h"
  #    define LLAMA_MAX_DEVICES GGML_SYCL_MAX_DEVICES
  #else
  #    define LLAMA_MAX_DEVICES 1
  #endif
  14. #include <stddef.h>
  15. #include <stdint.h>
  16. #include <stdio.h>
  17. #include <stdbool.h>
  // LLAMA_API decorates every public symbol.
  //   static build (LLAMA_SHARED undefined) : expands to nothing
  //   Windows DLL (not MinGW)                : dllexport while building llama (LLAMA_BUILD), dllimport otherwise
  //   other shared builds                    : default ELF/Mach-O symbol visibility
  #if !defined(LLAMA_SHARED)
  #    define LLAMA_API
  #elif defined(_WIN32) && !defined(__MINGW32__)
  #    if defined(LLAMA_BUILD)
  #        define LLAMA_API __declspec(dllexport)
  #    else
  #        define LLAMA_API __declspec(dllimport)
  #    endif
  #else
  #    define LLAMA_API __attribute__ ((visibility ("default")))
  #endif
  // DEPRECATED(func, hint) attaches a deprecation diagnostic carrying `hint` to
  // the declaration `func`. GCC-compatible compilers take the attribute after the
  // declarator and MSVC takes it in front; other compilers get the plain declaration.
  #if defined(__GNUC__)
  #    define DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
  #elif defined(_MSC_VER)
  #    define DEPRECATED(func, hint) __declspec(deprecated(hint)) func
  #else
  #    define DEPRECATED(func, hint) func
  #endif
  // Default RNG seed. 0xFFFFFFFF equals (uint32_t)-1, the value that
  // llama_context_params.seed documents as "random".
  #define LLAMA_DEFAULT_SEED 0xFFFFFFFF

  // Upper bound on the size of the saved RNG state: 64 * 1024 (presumably bytes - confirm).
  #define LLAMA_MAX_RNG_STATE (64*1024)

  // Four-character file magics, first character in the most significant byte.
  #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
  #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'

  // Magic and format version that identify session files.
  #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
  #define LLAMA_SESSION_VERSION 4
  // LLAMA_SUPPORTS_GPU_OFFLOAD is defined whenever llama.cpp is built with at
  // least one backend that can offload model layers to a GPU.
  #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || \
      defined(GGML_USE_METAL)  || defined(GGML_USE_VULKAN)  || \
      defined(GGML_USE_SYCL)   || defined(GGML_USE_KOMPUTE)
  #    define LLAMA_SUPPORTS_GPU_OFFLOAD
  #endif
  49. #ifdef __cplusplus
  50. extern "C" {
  51. #endif
  52. //
  53. // C interface
  54. //
  55. // TODO: show sample usage
  56. //
  // Handles declared without a body here, so callers only ever hold pointers to them.
  struct llama_model;
  struct llama_context;

  // 32-bit signed scalar aliases used throughout the API.
  typedef int32_t llama_pos;    // position of a token within its sequence
  typedef int32_t llama_token;  // token id
  typedef int32_t llama_seq_id; // sequence id
  // Tokenizer family used by a model's vocabulary.
  enum llama_vocab_type {
      LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
      LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
  };
  // Category of a vocabulary entry, as returned by llama_token_get_type().
  enum llama_token_type {
      LLAMA_TOKEN_TYPE_UNDEFINED    = 0,
      LLAMA_TOKEN_TYPE_NORMAL       = 1,
      LLAMA_TOKEN_TYPE_UNKNOWN      = 2,
      LLAMA_TOKEN_TYPE_CONTROL      = 3,
      LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
      LLAMA_TOKEN_TYPE_UNUSED       = 5,
      LLAMA_TOKEN_TYPE_BYTE         = 6,
  };
  // Model file types (tensor storage formats).
  // Unless noted otherwise, the MOSTLY_* types leave 1d tensors unquantized.
  // Values 5 (Q4_2) and 6 (Q4_3) were removed and must not be reused.
  enum llama_ftype {
      LLAMA_FTYPE_ALL_F32              = 0,
      LLAMA_FTYPE_MOSTLY_F16           = 1,
      LLAMA_FTYPE_MOSTLY_Q4_0          = 2,
      LLAMA_FTYPE_MOSTLY_Q4_1          = 3,
      LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
      LLAMA_FTYPE_MOSTLY_Q8_0          = 7,
      LLAMA_FTYPE_MOSTLY_Q5_0          = 8,
      LLAMA_FTYPE_MOSTLY_Q5_1          = 9,
      LLAMA_FTYPE_MOSTLY_Q2_K          = 10,
      LLAMA_FTYPE_MOSTLY_Q3_K_S        = 11,
      LLAMA_FTYPE_MOSTLY_Q3_K_M        = 12,
      LLAMA_FTYPE_MOSTLY_Q3_K_L        = 13,
      LLAMA_FTYPE_MOSTLY_Q4_K_S        = 14,
      LLAMA_FTYPE_MOSTLY_Q4_K_M        = 15,
      LLAMA_FTYPE_MOSTLY_Q5_K_S        = 16,
      LLAMA_FTYPE_MOSTLY_Q5_K_M        = 17,
      LLAMA_FTYPE_MOSTLY_Q6_K          = 18,
      LLAMA_FTYPE_MOSTLY_IQ2_XXS       = 19,
      LLAMA_FTYPE_MOSTLY_IQ2_XS        = 20,
      LLAMA_FTYPE_MOSTLY_Q2_K_S        = 21,
      LLAMA_FTYPE_MOSTLY_Q3_K_XS       = 22,
      LLAMA_FTYPE_MOSTLY_IQ3_XXS       = 23,

      LLAMA_FTYPE_GUESSED              = 1024, // the model file does not specify its type
  };
  // RoPE scaling strategy (stored in llama_context_params.rope_scaling_type).
  enum llama_rope_scaling_type {
      LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
      LLAMA_ROPE_SCALING_NONE        = 0,
      LLAMA_ROPE_SCALING_LINEAR      = 1,
      LLAMA_ROPE_SCALING_YARN        = 2,
      // Highest valid value; keep in sync when adding new strategies.
      LLAMA_ROPE_SCALING_MAX_VALUE   = LLAMA_ROPE_SCALING_YARN,
  };
  // How a model is distributed when several GPUs are available.
  enum llama_split_mode {
      LLAMA_SPLIT_NONE  = 0, // everything on a single GPU
      LLAMA_SPLIT_LAYER = 1, // layers and KV are split across GPUs
      LLAMA_SPLIT_ROW   = 2, // rows are split across GPUs
  };
  115. typedef struct llama_token_data {
  116. llama_token id; // token id
  117. float logit; // log-odds of the token
  118. float p; // probability of the token
  119. } llama_token_data;
  120. typedef struct llama_token_data_array {
  121. llama_token_data * data;
  122. size_t size;
  123. bool sorted;
  124. } llama_token_data_array;
  125. typedef bool (*llama_progress_callback)(float progress, void *ctx);
  126. // Input data for llama_decode
  127. // A llama_batch object can contain input about one or many sequences
  128. // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
  129. //
  130. // - token : the token ids of the input (used when embd is NULL)
  131. // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
  132. // - pos : the positions of the respective token in the sequence
  133. // - seq_id : the sequence to which the respective token belongs
  134. // - logits : if zero, the logits for the respective token will not be output
  135. //
  136. typedef struct llama_batch {
  137. int32_t n_tokens;
  138. llama_token * token;
  139. float * embd;
  140. llama_pos * pos;
  141. int32_t * n_seq_id;
  142. llama_seq_id ** seq_id;
  143. int8_t * logits;
  144. // NOTE: helpers for smooth API transition - can be deprecated in the future
  145. // for future-proof code, use the above fields instead and ignore everything below
  146. //
  147. // pos[i] = all_pos_0 + i*all_pos_1
  148. //
  149. llama_pos all_pos_0; // used if pos == NULL
  150. llama_pos all_pos_1; // used if pos == NULL
  151. llama_seq_id all_seq_id; // used if seq_id == NULL
  152. } llama_batch;
  // Which member of the llama_model_kv_override union is active.
  enum llama_model_kv_override_type {
      LLAMA_KV_OVERRIDE_INT   = 0,
      LLAMA_KV_OVERRIDE_FLOAT = 1,
      LLAMA_KV_OVERRIDE_BOOL  = 2,
  };

  // Replacement value for one model-metadata key. `tag` selects the union member
  // that holds the value.
  struct llama_model_kv_override {
      char key[128];
      enum llama_model_kv_override_type tag;
      union {
          int64_t int_value;   // LLAMA_KV_OVERRIDE_INT
          double  float_value; // LLAMA_KV_OVERRIDE_FLOAT
          bool    bool_value;  // LLAMA_KV_OVERRIDE_BOOL
      };
  };
  167. struct llama_model_params {
  168. int32_t n_gpu_layers; // number of layers to store in VRAM
  169. enum llama_split_mode split_mode; // how to split the model across multiple GPUs
  170. // main_gpu interpretation depends on split_mode:
  171. // LLAMA_SPLIT_NONE: the GPU that is used for the entire model
  172. // LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
  173. // LLAMA_SPLIT_LAYER: ignored
  174. int32_t main_gpu;
  175. // proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
  176. const float * tensor_split;
  177. // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
  178. // If the provided progress_callback returns true, model loading continues.
  179. // If it returns false, model loading is immediately aborted.
  180. llama_progress_callback progress_callback;
  181. // context pointer passed to the progress callback
  182. void * progress_callback_user_data;
  183. // override key-value pairs of the model meta data
  184. const struct llama_model_kv_override * kv_overrides;
  185. // Keep the booleans together to avoid misalignment during copy-by-value.
  186. bool vocab_only; // only load the vocabulary, no weights
  187. bool use_mmap; // use mmap if possible
  188. bool use_mlock; // force system to keep model in RAM
  189. };
  190. struct llama_context_params {
  191. uint32_t seed; // RNG seed, -1 for random
  192. uint32_t n_ctx; // text context, 0 = from model
  193. uint32_t n_batch; // prompt processing maximum batch size
  194. uint32_t n_threads; // number of threads to use for generation
  195. uint32_t n_threads_batch; // number of threads to use for batch processing
  196. int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
  197. // ref: https://github.com/ggerganov/llama.cpp/pull/2054
  198. float rope_freq_base; // RoPE base frequency, 0 = from model
  199. float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
  200. float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
  201. float yarn_attn_factor; // YaRN magnitude scaling factor
  202. float yarn_beta_fast; // YaRN low correction dim
  203. float yarn_beta_slow; // YaRN high correction dim
  204. uint32_t yarn_orig_ctx; // YaRN original context size
  205. ggml_backend_sched_eval_callback cb_eval;
  206. void * cb_eval_user_data;
  207. enum ggml_type type_k; // data type for K cache
  208. enum ggml_type type_v; // data type for V cache
  209. // Keep the booleans together to avoid misalignment during copy-by-value.
  210. bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
  211. bool logits_all; // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
  212. bool embedding; // embedding mode only
  213. bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
  214. };
  215. // model quantization parameters
  216. typedef struct llama_model_quantize_params {
  217. int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
  218. enum llama_ftype ftype; // quantize to this llama_ftype
  219. bool allow_requantize; // allow quantizing non-f32/f16 tensors
  220. bool quantize_output_tensor; // quantize output.weight
  221. bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
  222. bool pure; // disable k-quant mixtures and quantize all tensors to the same type
  223. void * imatrix; // pointer to importance matrix data
  224. } llama_model_quantize_params;
  // Grammar object. It is only declared here, so callers handle it through pointers
  // (see llama_grammar_init / llama_grammar_copy / llama_grammar_free).
  struct llama_grammar;

  // Kinds of grammar elements.
  enum llama_gretype {
      LLAMA_GRETYPE_END            = 0, // end of a rule definition
      LLAMA_GRETYPE_ALT            = 1, // start of an alternate definition of the rule
      LLAMA_GRETYPE_RULE_REF       = 2, // non-terminal: reference to another rule
      LLAMA_GRETYPE_CHAR           = 3, // terminal: a single character (code point)
      LLAMA_GRETYPE_CHAR_NOT       = 4, // inverted char set ([^a], [^a-b], [^abc])

      // Turns the preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT into an
      // inclusive range ([a-z]).
      LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

      // Adds an alternative character to the preceding LLAMA_GRETYPE_CHAR or
      // LLAMA_GRETYPE_CHAR_RNG_UPPER ([ab], [a-zA]).
      LLAMA_GRETYPE_CHAR_ALT       = 6,
  };

  // One element of a grammar rule.
  typedef struct llama_grammar_element {
      enum llama_gretype type;
      uint32_t           value; // Unicode code point or rule ID, depending on `type`
  } llama_grammar_element;
  // Performance timing information. The *_ms fields are durations or timestamps
  // in milliseconds; the n_* fields are the matching counts.
  struct llama_timings {
      double  t_start_ms;
      double  t_end_ms;
      double  t_load_ms;
      double  t_sample_ms;
      double  t_p_eval_ms;
      double  t_eval_ms;

      int32_t n_sample;
      int32_t n_p_eval;
      int32_t n_eval;
  };
  262. // Helpers for getting default parameters
  263. LLAMA_API struct llama_model_params llama_model_default_params(void);
  264. LLAMA_API struct llama_context_params llama_context_default_params(void);
  265. LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
  266. // Initialize the llama + ggml backend
  267. // If numa is true, use NUMA optimizations
  268. // Call once at the start of the program
  269. LLAMA_API void llama_backend_init(bool numa);
  270. // Call once at the end of the program - currently only used for MPI
  271. LLAMA_API void llama_backend_free(void);
  272. LLAMA_API struct llama_model * llama_load_model_from_file(
  273. const char * path_model,
  274. struct llama_model_params params);
  275. LLAMA_API void llama_free_model(struct llama_model * model);
  276. LLAMA_API struct llama_context * llama_new_context_with_model(
  277. struct llama_model * model,
  278. struct llama_context_params params);
  279. // Frees all allocated memory
  280. LLAMA_API void llama_free(struct llama_context * ctx);
  281. LLAMA_API int64_t llama_time_us(void);
  282. LLAMA_API int32_t llama_max_devices(void);
  283. LLAMA_API bool llama_mmap_supported (void);
  284. LLAMA_API bool llama_mlock_supported(void);
  285. LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
  286. LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
  287. LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
  288. LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
  289. LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
  290. LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
  291. LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
  292. // Get the model's RoPE frequency scaling factor
  293. LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
  294. // Functions to access the model's GGUF metadata scalar values
  295. // - The functions return the length of the string on success, or -1 on failure
  296. // - The output string is always null-terminated and cleared on failure
  297. // - GGUF array values are not supported by these functions
  298. // Get metadata value as a string by key name
  299. LLAMA_API int32_t llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
  300. // Get the number of metadata key/value pairs
  301. LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
  302. // Get metadata key name by index
  303. LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
  304. // Get metadata value as a string by index
  305. LLAMA_API int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
  306. // Get a string describing the model type
  307. LLAMA_API int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
  308. // Returns the total size of all the tensors in the model in bytes
  309. LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
  310. // Returns the total number of parameters in the model
  311. LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
  312. // Get a llama model tensor
  313. LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
  314. // Returns 0 on success
  315. LLAMA_API uint32_t llama_model_quantize(
  316. const char * fname_inp,
  317. const char * fname_out,
  318. const llama_model_quantize_params * params);
  319. // Apply a LoRA adapter to a loaded model
  320. // path_base_model is the path to a higher quality model to use as a base for
  321. // the layers modified by the adapter. Can be NULL to use the current loaded model.
  322. // The model needs to be reloaded before applying a new adapter, otherwise the adapter
  323. // will be applied on top of the previous one
  324. // Returns 0 on success
  325. LLAMA_API DEPRECATED(int32_t llama_apply_lora_from_file(
  326. struct llama_context * ctx,
  327. const char * path_lora,
  328. float scale,
  329. const char * path_base_model,
  330. int32_t n_threads),
  331. "use llama_model_apply_lora_from_file instead");
  332. LLAMA_API int32_t llama_model_apply_lora_from_file(
  333. const struct llama_model * model,
  334. const char * path_lora,
  335. float scale,
  336. const char * path_base_model,
  337. int32_t n_threads);
  338. //
  339. // KV cache
  340. //
  341. // Information associated with an individual cell in the KV cache view.
  342. struct llama_kv_cache_view_cell {
  343. // The position for this cell. Takes KV cache shifts into account.
  344. // May be negative if the cell is not populated.
  345. llama_pos pos;
  346. };
  347. // An updateable view of the KV cache.
  348. struct llama_kv_cache_view {
  349. // Number of KV cache cells. This will be the same as the context size.
  350. int32_t n_cells;
  351. // Maximum number of sequences that can exist in a cell. It's not an error
  352. // if there are more sequences in a cell than this value, however they will
  353. // not be visible in the view cells_sequences.
  354. int32_t n_max_seq;
  355. // Number of tokens in the cache. For example, if there are two populated
  356. // cells, the first with 1 sequence id in it and the second with 2 sequence
  357. // ids then you'll have 3 tokens.
  358. int32_t token_count;
  359. // Number of populated cache cells.
  360. int32_t used_cells;
  361. // Maximum contiguous empty slots in the cache.
  362. int32_t max_contiguous;
  363. // Index to the start of the max_contiguous slot range. Can be negative
  364. // when cache is full.
  365. int32_t max_contiguous_idx;
  366. // Information for an individual cell.
  367. struct llama_kv_cache_view_cell * cells;
  368. // The sequences for each cell. There will be n_max_seq items per cell.
  369. llama_seq_id * cells_sequences;
  370. };
  371. // Create an empty KV cache view. (use only for debugging purposes)
  372. LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
  373. // Free a KV cache view. (use only for debugging purposes)
  374. LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
  375. // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
  376. LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
  377. // Returns the number of tokens in the KV cache (slow, use only for debug)
  378. // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
  379. LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx);
  380. // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
  381. LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
  382. // Clear the KV cache
  383. LLAMA_API void llama_kv_cache_clear(
  384. struct llama_context * ctx);
  385. // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
  386. // seq_id < 0 : match any sequence
  387. // p0 < 0 : [0, p1]
  388. // p1 < 0 : [p0, inf)
  389. LLAMA_API void llama_kv_cache_seq_rm(
  390. struct llama_context * ctx,
  391. llama_seq_id seq_id,
  392. llama_pos p0,
  393. llama_pos p1);
  394. // Copy all tokens that belong to the specified sequence to another sequence
  395. // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
  396. // p0 < 0 : [0, p1]
  397. // p1 < 0 : [p0, inf)
  398. LLAMA_API void llama_kv_cache_seq_cp(
  399. struct llama_context * ctx,
  400. llama_seq_id seq_id_src,
  401. llama_seq_id seq_id_dst,
  402. llama_pos p0,
  403. llama_pos p1);
  404. // Removes all tokens that do not belong to the specified sequence
  405. LLAMA_API void llama_kv_cache_seq_keep(
  406. struct llama_context * ctx,
  407. llama_seq_id seq_id);
  408. // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
  409. // If the KV cache is RoPEd, the KV data is updated accordingly
  410. // p0 < 0 : [0, p1]
  411. // p1 < 0 : [p0, inf)
  412. LLAMA_API void llama_kv_cache_seq_shift(
  413. struct llama_context * ctx,
  414. llama_seq_id seq_id,
  415. llama_pos p0,
  416. llama_pos p1,
  417. llama_pos delta);
  418. // Integer division of the positions by factor of `d > 1`
  419. // If the KV cache is RoPEd, the KV data is updated accordingly
  420. // p0 < 0 : [0, p1]
  421. // p1 < 0 : [p0, inf)
  422. LLAMA_API void llama_kv_cache_seq_div(
  423. struct llama_context * ctx,
  424. llama_seq_id seq_id,
  425. llama_pos p0,
  426. llama_pos p1,
  427. int d);
  428. //
  429. // State / sessions
  430. //
  431. // Returns the maximum size in bytes of the state (rng, logits, embedding
  432. // and kv_cache) - will often be smaller after compacting tokens
  433. LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
  434. // Copies the state to the specified destination address.
  435. // Destination needs to have allocated enough memory.
  436. // Returns the number of bytes copied
  437. LLAMA_API size_t llama_copy_state_data(
  438. struct llama_context * ctx,
  439. uint8_t * dst);
  440. // Set the state reading from the specified address
  441. // Returns the number of bytes read
  442. LLAMA_API size_t llama_set_state_data(
  443. struct llama_context * ctx,
  444. uint8_t * src);
  445. // Save/load session file
  446. LLAMA_API bool llama_load_session_file(
  447. struct llama_context * ctx,
  448. const char * path_session,
  449. llama_token * tokens_out,
  450. size_t n_token_capacity,
  451. size_t * n_token_count_out);
  452. LLAMA_API bool llama_save_session_file(
  453. struct llama_context * ctx,
  454. const char * path_session,
  455. const llama_token * tokens,
  456. size_t n_token_count);
  457. //
  458. // Decoding
  459. //
  460. // Run the llama inference to obtain the logits and probabilities for the next token(s).
  461. // tokens + n_tokens is the provided batch of new tokens to process
  462. // n_past is the number of tokens to use from previous eval calls
  463. // Returns 0 on success
  464. // DEPRECATED: use llama_decode() instead
  465. LLAMA_API DEPRECATED(int llama_eval(
  466. struct llama_context * ctx,
  467. llama_token * tokens,
  468. int32_t n_tokens,
  469. int32_t n_past),
  470. "use llama_decode() instead");
  471. // Same as llama_eval, but use float matrix input directly.
  472. // DEPRECATED: use llama_decode() instead
  473. LLAMA_API DEPRECATED(int llama_eval_embd(
  474. struct llama_context * ctx,
  475. float * embd,
  476. int32_t n_tokens,
  477. int32_t n_past),
  478. "use llama_decode() instead");
  479. // Return batch for single sequence of tokens starting at pos_0
  480. //
  481. // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
  482. //
  483. LLAMA_API struct llama_batch llama_batch_get_one(
  484. llama_token * tokens,
  485. int32_t n_tokens,
  486. llama_pos pos_0,
  487. llama_seq_id seq_id);
  488. // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
  489. // Each token can be assigned up to n_seq_max sequence ids
  490. // The batch has to be freed with llama_batch_free()
  491. // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
  492. // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
  493. // The rest of the llama_batch members are allocated with size n_tokens
  494. // All members are left uninitialized
  495. LLAMA_API struct llama_batch llama_batch_init(
  496. int32_t n_tokens,
  497. int32_t embd,
  498. int32_t n_seq_max);
  499. // Frees a batch of tokens allocated with llama_batch_init()
  500. LLAMA_API void llama_batch_free(struct llama_batch batch);
  501. // Positive return values does not mean a fatal error, but rather a warning.
  502. // 0 - success
  503. // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
  504. // < 0 - error
  505. LLAMA_API int32_t llama_decode(
  506. struct llama_context * ctx,
  507. struct llama_batch batch);
  508. // Set the number of threads used for decoding
  509. // n_threads is the number of threads used for generation (single token)
  510. // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
  511. LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
  512. // Token logits obtained from the last call to llama_eval()
  513. // The logits for the last token are stored in the last row
  514. // Logits for which llama_batch.logits[i] == 0 are undefined
  515. // Rows: n_tokens provided with llama_batch
  516. // Cols: n_vocab
  517. LLAMA_API float * llama_get_logits(struct llama_context * ctx);
  518. // Logits for the ith token. Equivalent to:
  519. // llama_get_logits(ctx) + i*n_vocab
  520. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
  521. // Get the embeddings for the input
  522. // shape: [n_embd] (1-dimensional)
  523. LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
  524. //
  525. // Vocab
  526. //
  527. LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
  528. LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
  529. LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
  530. // Special tokens
  531. LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
  532. LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
  533. LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
  534. // Returns -1 if unknown, 1 for true or 0 for false.
  535. LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
  536. // Returns -1 if unknown, 1 for true or 0 for false.
  537. LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
  538. // codellama infill tokens
  539. LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
  540. LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
  541. LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
  542. LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
  543. //
  544. // Tokenization
  545. //
  546. /// @details Convert the provided text into tokens.
  547. /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
  548. /// @return Returns the number of tokens on success, no more than n_max_tokens
  549. /// @return Returns a negative number on failure - the number of tokens that would have been returned
  550. /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
  551. /// Does not insert a leading space.
  552. LLAMA_API int32_t llama_tokenize(
  553. const struct llama_model * model,
  554. const char * text,
  555. int32_t text_len,
  556. llama_token * tokens,
  557. int32_t n_max_tokens,
  558. bool add_bos,
  559. bool special);
  560. // Token Id -> Piece.
  561. // Uses the vocabulary in the provided context.
  562. // Does not write null terminator to the buffer.
  563. // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
  564. LLAMA_API int32_t llama_token_to_piece(
  565. const struct llama_model * model,
  566. llama_token token,
  567. char * buf,
  568. int32_t length);
  569. //
  570. // Grammar
  571. //
  572. LLAMA_API struct llama_grammar * llama_grammar_init(
  573. const llama_grammar_element ** rules,
  574. size_t n_rules,
  575. size_t start_rule_index);
  576. LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
  577. LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
  578. //
  579. // Sampling functions
  580. //
  581. // Sets the current rng seed.
  582. LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
  583. /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
  584. /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
  585. LLAMA_API void llama_sample_repetition_penalties(
  586. struct llama_context * ctx,
  587. llama_token_data_array * candidates,
  588. const llama_token * last_tokens,
  589. size_t penalty_last_n,
  590. float penalty_repeat,
  591. float penalty_freq,
  592. float penalty_present);
  593. /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
  594. /// @param logits Logits extracted from the original generation context.
  595. /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
  596. /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
  597. LLAMA_API void llama_sample_apply_guidance(
  598. struct llama_context * ctx,
  599. float * logits,
  600. float * logits_guidance,
  601. float scale);
  602. LLAMA_API DEPRECATED(void llama_sample_classifier_free_guidance(
  603. struct llama_context * ctx,
  604. llama_token_data_array * candidates,
  605. struct llama_context * guidance_ctx,
  606. float scale),
  607. "use llama_sample_apply_guidance() instead");
  608. /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
  609. LLAMA_API void llama_sample_softmax(
  610. struct llama_context * ctx,
  611. llama_token_data_array * candidates);
  612. /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
  613. LLAMA_API void llama_sample_top_k(
  614. struct llama_context * ctx,
  615. llama_token_data_array * candidates,
  616. int32_t k,
  617. size_t min_keep);
  618. /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
  619. LLAMA_API void llama_sample_top_p(
  620. struct llama_context * ctx,
  621. llama_token_data_array * candidates,
  622. float p,
  623. size_t min_keep);
  624. /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
  625. LLAMA_API void llama_sample_min_p(
  626. struct llama_context * ctx,
  627. llama_token_data_array * candidates,
  628. float p,
  629. size_t min_keep);
  630. /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
  631. LLAMA_API void llama_sample_tail_free(
  632. struct llama_context * ctx,
  633. llama_token_data_array * candidates,
  634. float z,
  635. size_t min_keep);
  636. /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
  637. LLAMA_API void llama_sample_typical(
  638. struct llama_context * ctx,
  639. llama_token_data_array * candidates,
  640. float p,
  641. size_t min_keep);
  642. /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
  643. LLAMA_API void llama_sample_entropy(
  644. struct llama_context * ctx,
  645. llama_token_data_array * candidates_p,
  646. float min_temp,
  647. float max_temp,
  648. float exponent_val);
  649. LLAMA_API void llama_sample_temp(
  650. struct llama_context * ctx,
  651. llama_token_data_array * candidates,
  652. float temp);
  653. LLAMA_API DEPRECATED(void llama_sample_temperature(
  654. struct llama_context * ctx,
  655. llama_token_data_array * candidates,
  656. float temp),
  657. "use llama_sample_temp instead");
  658. /// @details Apply constraints from grammar
  659. LLAMA_API void llama_sample_grammar(
  660. struct llama_context * ctx,
  661. llama_token_data_array * candidates,
  662. const struct llama_grammar * grammar);
  663. /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
  664. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
  665. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
  666. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
  667. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
  668. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
  669. LLAMA_API llama_token llama_sample_token_mirostat(
  670. struct llama_context * ctx,
  671. llama_token_data_array * candidates,
  672. float tau,
  673. float eta,
  674. int32_t m,
  675. float * mu);
  676. /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
  677. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
  678. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
  679. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
  680. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
  681. LLAMA_API llama_token llama_sample_token_mirostat_v2(
  682. struct llama_context * ctx,
  683. llama_token_data_array * candidates,
  684. float tau,
  685. float eta,
  686. float * mu);
  687. /// @details Selects the token with the highest probability.
  688. /// Does not compute the token probabilities. Use llama_sample_softmax() instead.
  689. LLAMA_API llama_token llama_sample_token_greedy(
  690. struct llama_context * ctx,
  691. llama_token_data_array * candidates);
  692. /// @details Randomly selects a token from the candidates based on their probabilities.
  693. LLAMA_API llama_token llama_sample_token(
  694. struct llama_context * ctx,
  695. llama_token_data_array * candidates);
  696. /// @details Accepts the sampled token into the grammar
  697. LLAMA_API void llama_grammar_accept_token(
  698. struct llama_context * ctx,
  699. struct llama_grammar * grammar,
  700. llama_token token);
  701. //
  702. // Beam search
  703. //
  704. struct llama_beam_view {
  705. const llama_token * tokens;
  706. size_t n_tokens;
  707. float p; // Cumulative beam probability (renormalized relative to all beams)
  708. bool eob; // Callback should set this to true when a beam is at end-of-beam.
  709. };
  /// State passed to the beam_search_callback function on each invocation.
  /// Whenever `common_prefix_length > 0`, that many tokens should be copied from any of the beams
  /// (e.g. beam_views[0]), because they will be removed (shifted) from all beams in all subsequent callbacks.
  /// The pointers inside are valid only during the synchronous callback and must not be saved.
  struct llama_beams_state {
      struct llama_beam_view * beam_views;           ///< Array of beam views.
      size_t                   n_beams;              ///< Number of elements in beam_views[].
      size_t                   common_prefix_length; ///< Current max length of prefix tokens shared by all beams.
      bool                     last_call;            ///< True iff this is the last callback invocation.
  };
  /// Pointer type for the beam_search_callback function.
  /// `callback_data` is whatever custom pointer was given to llama_beam_search(); it is passed back
  /// to the callback so that global variables are not needed. The beams state arrives by value.
  typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
  724. /// @details Deterministically returns entire sentence constructed by a beam search.
  725. /// @param ctx Pointer to the llama_context.
  726. /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
  727. /// @param callback_data A pointer that is simply passed back to callback.
  728. /// @param n_beams Number of beams to use.
  729. /// @param n_past Number of tokens already evaluated.
  730. /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
  731. LLAMA_API void llama_beam_search(
  732. struct llama_context * ctx,
  733. llama_beam_search_callback_fn_t callback,
  734. void * callback_data,
  735. size_t n_beams,
  736. int32_t n_past,
  737. int32_t n_predict);
  738. // Performance information
  739. LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
  740. LLAMA_API void llama_print_timings(struct llama_context * ctx);
  741. LLAMA_API void llama_reset_timings(struct llama_context * ctx);
  742. // Print system information
  743. LLAMA_API const char * llama_print_system_info(void);
  744. // Set callback for all future logging events.
  745. // If this is not called, or NULL is supplied, everything is output on stderr.
  746. LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
  747. LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
  748. #ifdef __cplusplus
  749. }
  750. #endif
  // Internal API to be implemented by llama.cpp and used by tests/benchmarks only.
  // This section is C++ (std:: containers) and is compiled only when LLAMA_API_INTERNAL is defined.
  #ifdef LLAMA_API_INTERNAL
  #include <string>
  #include <utility> // std::pair: not guaranteed to be provided transitively by <vector> or <string>
  #include <vector>

  struct ggml_tensor;

  // Returns, by const reference, the list of (string, tensor) pairs held for `ctx`
  // (presumably tensor name -> tensor; confirm in llama.cpp).
  const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
      struct llama_context * ctx
  );
  #endif // LLAMA_API_INTERNAL
  760. #endif // LLAMA_H