llama.h 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870
  1. #ifndef LLAMA_H
  2. #define LLAMA_H
  3. #include "ggml.h"
  4. #ifdef GGML_USE_CUBLAS
  5. #include "ggml-cuda.h"
  6. #define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
  7. #else
  8. #define LLAMA_MAX_DEVICES 1
  9. #endif // GGML_USE_CUBLAS
  10. #include <stddef.h>
  11. #include <stdint.h>
  12. #include <stdio.h>
  13. #include <stdbool.h>
  14. #ifdef LLAMA_SHARED
  15. # if defined(_WIN32) && !defined(__MINGW32__)
  16. # ifdef LLAMA_BUILD
  17. # define LLAMA_API __declspec(dllexport)
  18. # else
  19. # define LLAMA_API __declspec(dllimport)
  20. # endif
  21. # else
  22. # define LLAMA_API __attribute__ ((visibility ("default")))
  23. # endif
  24. #else
  25. # define LLAMA_API
  26. #endif
  27. #ifdef __GNUC__
  28. # define DEPRECATED(func, hint) func __attribute__((deprecated(hint)))
  29. #elif defined(_MSC_VER)
  30. # define DEPRECATED(func, hint) __declspec(deprecated(hint)) func
  31. #else
  32. # define DEPRECATED(func, hint) func
  33. #endif
  34. #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
  35. #define LLAMA_MAX_RNG_STATE (64*1024)
  36. #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
  37. #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
  38. #define LLAMA_SESSION_VERSION 3
  39. #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
  40. // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
  41. #define LLAMA_SUPPORTS_GPU_OFFLOAD
  42. #endif
  #ifdef __cplusplus
  extern "C" {
  #endif

  //
  // C interface
  //
  // TODO: show sample usage
  //

  // Forward declarations only; the API hands these out and takes them back by pointer.
  struct llama_model;
  struct llama_context;

  // 32-bit scalar aliases: position within a sequence, token id, sequence id.
  typedef int32_t llama_pos;
  typedef int32_t llama_token;
  typedef int32_t llama_seq_id;
  // Tokenizer algorithm behind a model's vocabulary.
  enum llama_vocab_type {
      // SentencePiece
      LLAMA_VOCAB_TYPE_SPM = 0,
      // Byte Pair Encoding
      LLAMA_VOCAB_TYPE_BPE = 1,
  };
  // Classification attached to each vocabulary entry.
  enum llama_token_type {
      LLAMA_TOKEN_TYPE_UNDEFINED    = 0,
      LLAMA_TOKEN_TYPE_NORMAL       = 1,
      LLAMA_TOKEN_TYPE_UNKNOWN      = 2,
      LLAMA_TOKEN_TYPE_CONTROL      = 3,
      LLAMA_TOKEN_TYPE_USER_DEFINED = 4,
      LLAMA_TOKEN_TYPE_UNUSED       = 5,
      LLAMA_TOKEN_TYPE_BYTE         = 6,
  };
  //
  // Model file types.
  // Every MOSTLY_* type except MOSTLY_Q4_1_SOME_F16 applies to all tensors except 1d tensors.
  // Values 5 and 6 belonged to formats whose support has been removed and are left unassigned.
  //
  enum llama_ftype {
      LLAMA_FTYPE_ALL_F32              = 0,
      LLAMA_FTYPE_MOSTLY_F16           = 1,
      LLAMA_FTYPE_MOSTLY_Q4_0          = 2,
      LLAMA_FTYPE_MOSTLY_Q4_1          = 3,
      LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4,  // tok_embeddings.weight and output.weight are F16
   // LLAMA_FTYPE_MOSTLY_Q4_2          = 5,  // support has been removed
   // LLAMA_FTYPE_MOSTLY_Q4_3          = 6,  // support has been removed
      LLAMA_FTYPE_MOSTLY_Q8_0          = 7,
      LLAMA_FTYPE_MOSTLY_Q5_0          = 8,
      LLAMA_FTYPE_MOSTLY_Q5_1          = 9,
      LLAMA_FTYPE_MOSTLY_Q2_K          = 10,
      LLAMA_FTYPE_MOSTLY_Q3_K_S        = 11,
      LLAMA_FTYPE_MOSTLY_Q3_K_M        = 12,
      LLAMA_FTYPE_MOSTLY_Q3_K_L        = 13,
      LLAMA_FTYPE_MOSTLY_Q4_K_S        = 14,
      LLAMA_FTYPE_MOSTLY_Q4_K_M        = 15,
      LLAMA_FTYPE_MOSTLY_Q5_K_S        = 16,
      LLAMA_FTYPE_MOSTLY_Q5_K_M        = 17,
      LLAMA_FTYPE_MOSTLY_Q6_K          = 18,

      LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
  };
  // RoPE scaling strategies; UNSPECIFIED is the negative "not set" marker.
  enum llama_rope_scaling_type {
      LLAMA_ROPE_SCALING_UNSPECIFIED = -1,
      LLAMA_ROPE_SCALING_NONE        = 0,
      LLAMA_ROPE_SCALING_LINEAR      = 1,
      LLAMA_ROPE_SCALING_YARN        = 2,

      // Highest valid value; keep in sync with the last entry above.
      LLAMA_ROPE_SCALING_MAX_VALUE   = LLAMA_ROPE_SCALING_YARN,
  };
  99. typedef struct llama_token_data {
  100. llama_token id; // token id
  101. float logit; // log-odds of the token
  102. float p; // probability of the token
  103. } llama_token_data;
  104. typedef struct llama_token_data_array {
  105. llama_token_data * data;
  106. size_t size;
  107. bool sorted;
  108. } llama_token_data_array;
  109. typedef void (*llama_progress_callback)(float progress, void *ctx);
  110. // Input data for llama_decode
  111. // A llama_batch object can contain input about one or many sequences
  112. // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
  113. //
  114. // - token : the token ids of the input (used when embd is NULL)
  115. // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
  116. // - pos : the positions of the respective token in the sequence
  117. // - seq_id : the sequence to which the respective token belongs
  118. // - logits : if zero, the logits for the respective token will not be output
  119. //
  120. typedef struct llama_batch {
  121. int32_t n_tokens;
  122. llama_token * token;
  123. float * embd;
  124. llama_pos * pos;
  125. int32_t * n_seq_id;
  126. llama_seq_id ** seq_id;
  127. int8_t * logits;
  128. // NOTE: helpers for smooth API transition - can be deprecated in the future
  129. // for future-proof code, use the above fields instead and ignore everything below
  130. //
  131. // pos[i] = all_pos_0 + i*all_pos_1
  132. //
  133. llama_pos all_pos_0; // used if pos == NULL
  134. llama_pos all_pos_1; // used if pos == NULL
  135. llama_seq_id all_seq_id; // used if seq_id == NULL
  136. } llama_batch;
  // Selects which member of llama_model_kv_override's value union is meaningful.
  enum llama_model_kv_override_type {
      LLAMA_KV_OVERRIDE_INT   = 0,
      LLAMA_KV_OVERRIDE_FLOAT = 1,
      LLAMA_KV_OVERRIDE_BOOL  = 2,
  };

  // One override for a model metadata key: the key name, a type tag, and a value
  // stored in the union member that matches the tag.
  struct llama_model_kv_override {
      char key[128];
      enum llama_model_kv_override_type tag;
      union {
          int64_t int_value;   // LLAMA_KV_OVERRIDE_INT
          double  float_value; // LLAMA_KV_OVERRIDE_FLOAT
          bool    bool_value;  // LLAMA_KV_OVERRIDE_BOOL
      };
  };
  151. struct llama_model_params {
  152. int32_t n_gpu_layers; // number of layers to store in VRAM
  153. int32_t main_gpu; // the GPU that is used for scratch and small tensors
  154. const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
  155. // called with a progress value between 0 and 1, pass NULL to disable
  156. llama_progress_callback progress_callback;
  157. // context pointer passed to the progress callback
  158. void * progress_callback_user_data;
  159. // override key-value pairs of the model meta data
  160. const struct llama_model_kv_override * kv_overrides;
  161. // Keep the booleans together to avoid misalignment during copy-by-value.
  162. bool vocab_only; // only load the vocabulary, no weights
  163. bool use_mmap; // use mmap if possible
  164. bool use_mlock; // force system to keep model in RAM
  165. };
  166. struct llama_context_params {
  167. uint32_t seed; // RNG seed, -1 for random
  168. uint32_t n_ctx; // text context, 0 = from model
  169. uint32_t n_batch; // prompt processing maximum batch size
  170. uint32_t n_threads; // number of threads to use for generation
  171. uint32_t n_threads_batch; // number of threads to use for batch processing
  172. int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
  173. // ref: https://github.com/ggerganov/llama.cpp/pull/2054
  174. float rope_freq_base; // RoPE base frequency, 0 = from model
  175. float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
  176. float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
  177. float yarn_attn_factor; // YaRN magnitude scaling factor
  178. float yarn_beta_fast; // YaRN low correction dim
  179. float yarn_beta_slow; // YaRN high correction dim
  180. uint32_t yarn_orig_ctx; // YaRN original context size
  181. enum ggml_type type_k; // data type for K cache
  182. enum ggml_type type_v; // data type for V cache
  183. // Keep the booleans together to avoid misalignment during copy-by-value.
  184. bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
  185. bool logits_all; // the llama_eval() call computes all logits, not just the last one
  186. bool embedding; // embedding mode only
  187. bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
  188. };
  189. // model quantization parameters
  190. typedef struct llama_model_quantize_params {
  191. int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
  192. enum llama_ftype ftype; // quantize to this llama_ftype
  193. bool allow_requantize; // allow quantizing non-f32/f16 tensors
  194. bool quantize_output_tensor; // quantize output.weight
  195. bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored
  196. bool pure; // disable k-quant mixtures and quantize all tensors to the same type
  197. } llama_model_quantize_params;
  // Opaque compiled grammar, used through pointers.
  struct llama_grammar;

  // Kinds of element that make up a grammar rule.
  enum llama_gretype {
      LLAMA_GRETYPE_END            = 0, // terminates a rule definition
      LLAMA_GRETYPE_ALT            = 1, // begins an alternate definition of the rule
      LLAMA_GRETYPE_RULE_REF       = 2, // non-terminal: reference to another rule
      LLAMA_GRETYPE_CHAR           = 3, // terminal: a character (code point)
      LLAMA_GRETYPE_CHAR_NOT       = 4, // inverse char(s): [^a], [^a-b], [^abc]

      // Follows LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT, turning it into an
      // inclusive range ([a-z]).
      LLAMA_GRETYPE_CHAR_RNG_UPPER = 5,

      // Follows LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_RNG_UPPER, adding another
      // character to match ([ab], [a-zA]).
      LLAMA_GRETYPE_CHAR_ALT       = 6,
  };

  // One grammar element: its kind plus a payload.
  typedef struct llama_grammar_element {
      enum llama_gretype type;
      uint32_t           value; // Unicode code point or rule ID
  } llama_grammar_element;
  // Performance counters: timings in milliseconds (the *_ms fields) and event counts.
  struct llama_timings {
      double t_start_ms;
      double t_end_ms;
      double t_load_ms;
      double t_sample_ms;
      double t_p_eval_ms;
      double t_eval_ms;

      int32_t n_sample;
      int32_t n_p_eval;
      int32_t n_eval;
  };
  235. // Helpers for getting default parameters
  236. LLAMA_API struct llama_model_params llama_model_default_params(void);
  237. LLAMA_API struct llama_context_params llama_context_default_params(void);
  238. LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);
  239. // Initialize the llama + ggml backend
  240. // If numa is true, use NUMA optimizations
  241. // Call once at the start of the program
  242. LLAMA_API void llama_backend_init(bool numa);
  243. // Call once at the end of the program - currently only used for MPI
  244. LLAMA_API void llama_backend_free(void);
  245. LLAMA_API struct llama_model * llama_load_model_from_file(
  246. const char * path_model,
  247. struct llama_model_params params);
  248. LLAMA_API void llama_free_model(struct llama_model * model);
  249. LLAMA_API struct llama_context * llama_new_context_with_model(
  250. struct llama_model * model,
  251. struct llama_context_params params);
  252. // Frees all allocated memory
  253. LLAMA_API void llama_free(struct llama_context * ctx);
  254. LLAMA_API int64_t llama_time_us(void);
  255. LLAMA_API int llama_max_devices (void);
  256. LLAMA_API bool llama_mmap_supported (void);
  257. LLAMA_API bool llama_mlock_supported(void);
  258. LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
  259. LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
  260. LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
  261. LLAMA_API int llama_n_vocab (const struct llama_model * model);
  262. LLAMA_API int llama_n_ctx_train(const struct llama_model * model);
  263. LLAMA_API int llama_n_embd (const struct llama_model * model);
  264. // Get the model's RoPE frequency scaling factor
  265. LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model);
  266. // Functions to access the model's GGUF metadata scalar values
  267. // - The functions return the length of the string on success, or -1 on failure
  268. // - The output string is always null-terminated and cleared on failure
  269. // - GGUF array values are not supported by these functions
  270. // Get metadata value as a string by key name
  271. LLAMA_API int llama_model_meta_val_str(const struct llama_model * model, const char * key, char * buf, size_t buf_size);
  272. // Get the number of metadata key/value pairs
  273. LLAMA_API int llama_model_meta_count(const struct llama_model * model);
  274. // Get metadata key name by index
  275. LLAMA_API int llama_model_meta_key_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
  276. // Get metadata value as a string by index
  277. LLAMA_API int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, char * buf, size_t buf_size);
  278. // Get a string describing the model type
  279. LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
  280. // Returns the total size of all the tensors in the model in bytes
  281. LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
  282. // Returns the total number of parameters in the model
  283. LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
  284. // Get a llama model tensor
  285. LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
  286. // Returns 0 on success
  287. LLAMA_API int llama_model_quantize(
  288. const char * fname_inp,
  289. const char * fname_out,
  290. const llama_model_quantize_params * params);
  291. // Apply a LoRA adapter to a loaded model
  292. // path_base_model is the path to a higher quality model to use as a base for
  293. // the layers modified by the adapter. Can be NULL to use the current loaded model.
  294. // The model needs to be reloaded before applying a new adapter, otherwise the adapter
  295. // will be applied on top of the previous one
  296. // Returns 0 on success
  297. LLAMA_API DEPRECATED(int llama_apply_lora_from_file(
  298. struct llama_context * ctx,
  299. const char * path_lora,
  300. float scale,
  301. const char * path_base_model,
  302. int n_threads),
  303. "use llama_model_apply_lora_from_file instead");
  304. LLAMA_API int llama_model_apply_lora_from_file(
  305. const struct llama_model * model,
  306. const char * path_lora,
  307. float scale,
  308. const char * path_base_model,
  309. int n_threads);
  310. //
  311. // KV cache
  312. //
  313. // Information associated with an individual cell in the KV cache view.
  314. struct llama_kv_cache_view_cell {
  315. // The position for this cell. Takes KV cache shifts into account.
  316. // May be negative if the cell is not populated.
  317. llama_pos pos;
  318. };
  319. // An updateable view of the KV cache.
  320. struct llama_kv_cache_view {
  321. // Number of KV cache cells. This will be the same as the context size.
  322. int32_t n_cells;
  323. // Maximum number of sequences that can exist in a cell. It's not an error
  324. // if there are more sequences in a cell than this value, however they will
  325. // not be visible in the view cells_sequences.
  326. int32_t n_max_seq;
  327. // Number of tokens in the cache. For example, if there are two populated
  328. // cells, the first with 1 sequence id in it and the second with 2 sequence
  329. // ids then you'll have 3 tokens.
  330. int32_t token_count;
  331. // Number of populated cache cells.
  332. int32_t used_cells;
  333. // Maximum contiguous empty slots in the cache.
  334. int32_t max_contiguous;
  335. // Index to the start of the max_contiguous slot range. Can be negative
  336. // when cache is full.
  337. int32_t max_contiguous_idx;
  338. // Information for an individual cell.
  339. struct llama_kv_cache_view_cell * cells;
  340. // The sequences for each cell. There will be n_max_seq items per cell.
  341. llama_seq_id * cells_sequences;
  342. };
  343. // Create an empty KV cache view. (use only for debugging purposes)
  344. LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
  345. // Free a KV cache view. (use only for debugging purposes)
  346. LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
  347. // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
  348. LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
  349. // Returns the number of tokens in the KV cache (slow, use only for debug)
  350. // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
  351. LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
  352. // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
  353. LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);
  354. // Clear the KV cache
  355. LLAMA_API void llama_kv_cache_clear(
  356. struct llama_context * ctx);
  357. // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
  358. // seq_id < 0 : match any sequence
  359. // p0 < 0 : [0, p1]
  360. // p1 < 0 : [p0, inf)
  361. LLAMA_API void llama_kv_cache_seq_rm(
  362. struct llama_context * ctx,
  363. llama_seq_id seq_id,
  364. llama_pos p0,
  365. llama_pos p1);
  366. // Copy all tokens that belong to the specified sequence to another sequence
  367. // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence
  368. // p0 < 0 : [0, p1]
  369. // p1 < 0 : [p0, inf)
  370. LLAMA_API void llama_kv_cache_seq_cp(
  371. struct llama_context * ctx,
  372. llama_seq_id seq_id_src,
  373. llama_seq_id seq_id_dst,
  374. llama_pos p0,
  375. llama_pos p1);
  376. // Removes all tokens that do not belong to the specified sequence
  377. LLAMA_API void llama_kv_cache_seq_keep(
  378. struct llama_context * ctx,
  379. llama_seq_id seq_id);
  380. // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
  381. // If the KV cache is RoPEd, the KV data is updated accordingly
  382. // p0 < 0 : [0, p1]
  383. // p1 < 0 : [p0, inf)
  384. LLAMA_API void llama_kv_cache_seq_shift(
  385. struct llama_context * ctx,
  386. llama_seq_id seq_id,
  387. llama_pos p0,
  388. llama_pos p1,
  389. llama_pos delta);
  390. //
  391. // State / sessions
  392. //
  393. // Returns the maximum size in bytes of the state (rng, logits, embedding
  394. // and kv_cache) - will often be smaller after compacting tokens
  395. LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
  396. // Copies the state to the specified destination address.
  397. // Destination needs to have allocated enough memory.
  398. // Returns the number of bytes copied
  399. LLAMA_API size_t llama_copy_state_data(
  400. struct llama_context * ctx,
  401. uint8_t * dst);
  402. // Set the state reading from the specified address
  403. // Returns the number of bytes read
  404. LLAMA_API size_t llama_set_state_data(
  405. struct llama_context * ctx,
  406. uint8_t * src);
  407. // Save/load session file
  408. LLAMA_API bool llama_load_session_file(
  409. struct llama_context * ctx,
  410. const char * path_session,
  411. llama_token * tokens_out,
  412. size_t n_token_capacity,
  413. size_t * n_token_count_out);
  414. LLAMA_API bool llama_save_session_file(
  415. struct llama_context * ctx,
  416. const char * path_session,
  417. const llama_token * tokens,
  418. size_t n_token_count);
  419. //
  420. // Decoding
  421. //
  422. // Run the llama inference to obtain the logits and probabilities for the next token(s).
  423. // tokens + n_tokens is the provided batch of new tokens to process
  424. // n_past is the number of tokens to use from previous eval calls
  425. // Returns 0 on success
  426. // DEPRECATED: use llama_decode() instead
  427. LLAMA_API DEPRECATED(int llama_eval(
  428. struct llama_context * ctx,
  429. llama_token * tokens,
  430. int32_t n_tokens,
  431. int n_past),
  432. "use llama_decode() instead");
  433. // Same as llama_eval, but use float matrix input directly.
  434. // DEPRECATED: use llama_decode() instead
  435. LLAMA_API DEPRECATED(int llama_eval_embd(
  436. struct llama_context * ctx,
  437. float * embd,
  438. int32_t n_tokens,
  439. int n_past),
  440. "use llama_decode() instead");
  441. // Return batch for single sequence of tokens starting at pos_0
  442. //
  443. // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
  444. //
  445. LLAMA_API struct llama_batch llama_batch_get_one(
  446. llama_token * tokens,
  447. int32_t n_tokens,
  448. llama_pos pos_0,
  449. llama_seq_id seq_id);
  450. // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
  451. // Each token can be assigned up to n_seq_max sequence ids
  452. // The batch has to be freed with llama_batch_free()
  453. // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
  454. // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
  455. // The rest of the llama_batch members are allocated with size n_tokens
  456. // All members are left uninitialized
  457. LLAMA_API struct llama_batch llama_batch_init(
  458. int32_t n_tokens,
  459. int32_t embd,
  460. int32_t n_seq_max);
  461. // Frees a batch of tokens allocated with llama_batch_init()
  462. LLAMA_API void llama_batch_free(struct llama_batch batch);
  463. // Positive return values does not mean a fatal error, but rather a warning.
  464. // 0 - success
  465. // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
  466. // < 0 - error
  467. LLAMA_API int llama_decode(
  468. struct llama_context * ctx,
  469. struct llama_batch batch);
  470. // Set the number of threads used for decoding
  471. // n_threads is the number of threads used for generation (single token)
  472. // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
  473. LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
  474. // Token logits obtained from the last call to llama_eval()
  475. // The logits for the last token are stored in the last row
  476. // Logits for which llama_batch.logits[i] == 0 are undefined
  477. // Rows: n_tokens provided with llama_batch
  478. // Cols: n_vocab
  479. LLAMA_API float * llama_get_logits(struct llama_context * ctx);
  480. // Logits for the ith token. Equivalent to:
  481. // llama_get_logits(ctx) + i*n_vocab
  482. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
  483. // Get the embeddings for the input
  484. // shape: [n_embd] (1-dimensional)
  485. LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
  486. //
  487. // Vocab
  488. //
  489. LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token);
  490. LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
  491. LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
  492. // Special tokens
  493. LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
  494. LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
  495. LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
  496. // Returns -1 if unknown, 1 for true or 0 for false.
  497. LLAMA_API int llama_add_bos_token(const struct llama_model * model);
  498. // Returns -1 if unknown, 1 for true or 0 for false.
  499. LLAMA_API int llama_add_eos_token(const struct llama_model * model);
  500. // codellama infill tokens
  501. LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
  502. LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
  503. LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
  504. LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle
  505. //
  506. // Tokenization
  507. //
  508. /// @details Convert the provided text into tokens.
  509. /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
  510. /// @return Returns the number of tokens on success, no more than n_max_tokens
  511. /// @return Returns a negative number on failure - the number of tokens that would have been returned
  512. /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.
  513. /// Does not insert a leading space.
  514. LLAMA_API int llama_tokenize(
  515. const struct llama_model * model,
  516. const char * text,
  517. int text_len,
  518. llama_token * tokens,
  519. int n_max_tokens,
  520. bool add_bos,
  521. bool special);
  522. // Token Id -> Piece.
  523. // Uses the vocabulary in the provided context.
  524. // Does not write null terminator to the buffer.
  525. // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens.
  526. LLAMA_API int llama_token_to_piece(
  527. const struct llama_model * model,
  528. llama_token token,
  529. char * buf,
  530. int length);
  531. //
  532. // Grammar
  533. //
  534. LLAMA_API struct llama_grammar * llama_grammar_init(
  535. const llama_grammar_element ** rules,
  536. size_t n_rules,
  537. size_t start_rule_index);
  538. LLAMA_API void llama_grammar_free(struct llama_grammar * grammar);
  539. LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
  540. //
  541. // Sampling functions
  542. //
  543. // Sets the current rng seed.
  544. LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
  545. /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
  546. /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
  547. LLAMA_API void llama_sample_repetition_penalties(
  548. struct llama_context * ctx,
  549. llama_token_data_array * candidates,
  550. const llama_token * last_tokens,
  551. size_t penalty_last_n,
  552. float penalty_repeat,
  553. float penalty_freq,
  554. float penalty_present);
  555. /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
  556. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
  557. /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
  558. /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
  559. LLAMA_API void llama_sample_classifier_free_guidance(
  560. struct llama_context * ctx,
  561. llama_token_data_array * candidates,
  562. struct llama_context * guidance_ctx,
  563. float scale);
  564. /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
  565. LLAMA_API void llama_sample_softmax(
  566. struct llama_context * ctx,
  567. llama_token_data_array * candidates);
  568. /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
  569. LLAMA_API void llama_sample_top_k(
  570. struct llama_context * ctx,
  571. llama_token_data_array * candidates,
  572. int k,
  573. size_t min_keep);
  574. /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
  575. LLAMA_API void llama_sample_top_p(
  576. struct llama_context * ctx,
  577. llama_token_data_array * candidates,
  578. float p,
  579. size_t min_keep);
  580. /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
  581. LLAMA_API void llama_sample_min_p(
  582. struct llama_context * ctx,
  583. llama_token_data_array * candidates,
  584. float p,
  585. size_t min_keep);
  586. /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
  587. LLAMA_API void llama_sample_tail_free(
  588. struct llama_context * ctx,
  589. llama_token_data_array * candidates,
  590. float z,
  591. size_t min_keep);
  592. /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
  593. LLAMA_API void llama_sample_typical(
  594. struct llama_context * ctx,
  595. llama_token_data_array * candidates,
  596. float p,
  597. size_t min_keep);
  598. LLAMA_API void llama_sample_temp(
  599. struct llama_context * ctx,
  600. llama_token_data_array * candidates,
  601. float temp);
  602. LLAMA_API DEPRECATED(void llama_sample_temperature(
  603. struct llama_context * ctx,
  604. llama_token_data_array * candidates,
  605. float temp),
  606. "use llama_sample_temp instead");
  607. /// @details Apply constraints from grammar
  608. LLAMA_API void llama_sample_grammar(
  609. struct llama_context * ctx,
  610. llama_token_data_array * candidates,
  611. const struct llama_grammar * grammar);
  612. /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
  613. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
  614. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
  615. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
  616. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
  617. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
  618. LLAMA_API llama_token llama_sample_token_mirostat(
  619. struct llama_context * ctx,
  620. llama_token_data_array * candidates,
  621. float tau,
  622. float eta,
  623. int m,
  624. float * mu);
  625. /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
  626. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
  627. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
  628. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
  629. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
  630. LLAMA_API llama_token llama_sample_token_mirostat_v2(
  631. struct llama_context * ctx,
  632. llama_token_data_array * candidates,
  633. float tau,
  634. float eta,
  635. float * mu);
  636. /// @details Selects the token with the highest probability.
  637. /// Does not compute the token probabilities. Use llama_sample_softmax() instead.
  638. LLAMA_API llama_token llama_sample_token_greedy(
  639. struct llama_context * ctx,
  640. llama_token_data_array * candidates);
  641. /// @details Randomly selects a token from the candidates based on their probabilities.
  642. LLAMA_API llama_token llama_sample_token(
  643. struct llama_context * ctx,
  644. llama_token_data_array * candidates);
  645. /// @details Accepts the sampled token into the grammar
  646. LLAMA_API void llama_grammar_accept_token(
  647. struct llama_context * ctx,
  648. struct llama_grammar * grammar,
  649. llama_token token);
  650. //
  651. // Beam search
  652. //
  653. struct llama_beam_view {
  654. const llama_token * tokens;
  655. size_t n_tokens;
  656. float p; // Cumulative beam probability (renormalized relative to all beams)
  657. bool eob; // Callback should set this to true when a beam is at end-of-beam.
  658. };
  659. // Passed to beam_search_callback function.
  660. // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
  661. // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
  662. // These pointers are valid only during the synchronous callback, so should not be saved.
  663. struct llama_beams_state {
  664. struct llama_beam_view * beam_views;
  665. size_t n_beams; // Number of elements in beam_views[].
  666. size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
  667. bool last_call; // True iff this is the last callback invocation.
  668. };
  669. // Type of pointer to the beam_search_callback function.
  670. // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
  671. // passed back to beam_search_callback. This avoids having to use global variables in the callback.
  672. typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
  673. /// @details Deterministically returns entire sentence constructed by a beam search.
  674. /// @param ctx Pointer to the llama_context.
  675. /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
  676. /// @param callback_data A pointer that is simply passed back to callback.
  677. /// @param n_beams Number of beams to use.
  678. /// @param n_past Number of tokens already evaluated.
  679. /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
  680. LLAMA_API void llama_beam_search(
  681. struct llama_context * ctx,
  682. llama_beam_search_callback_fn_t callback,
  683. void * callback_data,
  684. size_t n_beams,
  685. int n_past,
  686. int n_predict);
  687. // Performance information
  688. LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx);
  689. LLAMA_API void llama_print_timings(struct llama_context * ctx);
  690. LLAMA_API void llama_reset_timings(struct llama_context * ctx);
  691. // Print system information
  692. LLAMA_API const char * llama_print_system_info(void);
  693. // Set callback for all future logging events.
  694. // If this is not called, or NULL is supplied, everything is output on stderr.
  695. LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data);
  696. LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
  697. #ifdef __cplusplus
  698. }
  699. #endif
  700. // Internal API to be implemented by llama.cpp and used by tests/benchmarks only
  701. #ifdef LLAMA_API_INTERNAL
  702. #include <vector>
  703. #include <string>
  704. struct ggml_tensor;
  705. const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
  706. struct llama_context * ctx
  707. );
  708. #endif // LLAMA_API_INTERNAL
  709. #endif // LLAMA_H