sampling.h 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. #pragma once
  2. #include "llama.h"
  3. #include <string>
  4. #include <vector>
  // Identifiers of the individual samplers that a gpt_sampler chain can be assembled from.
  // gpt_sampler_params::samplers holds a list of these; its default chain uses every value
  // except GPT_SAMPLER_TYPE_NONE.
  //
  // NOTE(review): each enumerator presumably selects the sampler tuned by the like-named
  // gpt_sampler_params field (top_k, top_p, min_p, tfs_z, typ_p, temp) - confirm in sampling.cpp.
  enum gpt_sampler_type {
      GPT_SAMPLER_TYPE_NONE        = 0,
      GPT_SAMPLER_TYPE_TOP_K       = 1,
      GPT_SAMPLER_TYPE_TOP_P       = 2,
      GPT_SAMPLER_TYPE_MIN_P       = 3,
      GPT_SAMPLER_TYPE_TFS_Z       = 4,
      GPT_SAMPLER_TYPE_TYPICAL_P   = 5,
      GPT_SAMPLER_TYPE_TEMPERATURE = 6,
  };
  14. // sampling parameters
  15. struct gpt_sampler_params {
  16. uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
  17. int32_t n_prev = 64; // number of previous tokens to remember
  18. int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
  19. int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
  20. int32_t top_k = 40; // <= 0 to use vocab size
  21. float top_p = 0.95f; // 1.0 = disabled
  22. float min_p = 0.05f; // 0.0 = disabled
  23. float tfs_z = 1.00f; // 1.0 = disabled
  24. float typ_p = 1.00f; // typical_p, 1.0 = disabled
  25. float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
  26. float dynatemp_range = 0.00f; // 0.0 = disabled
  27. float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
  28. int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
  29. float penalty_repeat = 1.00f; // 1.0 = disabled
  30. float penalty_freq = 0.00f; // 0.0 = disabled
  31. float penalty_present = 0.00f; // 0.0 = disabled
  32. int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
  33. float mirostat_tau = 5.00f; // target entropy
  34. float mirostat_eta = 0.10f; // learning rate
  35. bool penalize_nl = false; // consider newlines as a repeatable token
  36. bool ignore_eos = false;
  37. std::vector<enum gpt_sampler_type> samplers = {
  38. GPT_SAMPLER_TYPE_TOP_K,
  39. GPT_SAMPLER_TYPE_TFS_Z,
  40. GPT_SAMPLER_TYPE_TYPICAL_P,
  41. GPT_SAMPLER_TYPE_TOP_P,
  42. GPT_SAMPLER_TYPE_MIN_P,
  43. GPT_SAMPLER_TYPE_TEMPERATURE
  44. };
  45. std::string grammar; // optional BNF-like grammar to constrain sampling
  46. std::vector<llama_logit_bias> logit_bias; // logit biases to apply
  47. // print the parameters into a string
  48. std::string print() const;
  49. };
  // gpt_sampler extends llama_sampler with the following extras:
  //
  //   - grammar support
  //   - custom sampler-chain logic driven by the parameters
  //   - a history of the most recently accepted tokens
  //   - performance metrics
  //
  // The goal is to share one implementation of the sampling logic across all examples.
  // Depending on the temperature, for instance, the sampling chain can be trivial (greedy)
  // or a longer pipeline (top-k, top-p, etc).
  //
  // Grammars are another case in point. Applying grammar constraints to the entire
  // vocabulary can be very expensive. To save work, the grammar can instead be checked
  // against the sampled token alone. Only if that token does not fit the grammar are the
  // constraints applied to the full vocabulary, after which the token is resampled.
  //
  // gpt_sampler also keeps a container with the last accepted tokens. In the future this
  // container may move into the core llama library.
  //
  // As a convenience, gpt_sampler additionally keeps the current candidate tokens, which
  // gives access to the probabilities of the remaining, non-sampled tokens.
  //
  // TODO: measure grammar performance
  //
  // Declared here only as an incomplete (opaque) type; callers work through pointers.
  struct gpt_sampler;
  75. // llama_sampler API overloads
  76. struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
  77. void gpt_sampler_free(struct gpt_sampler * gsmpl);
  78. // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
  79. void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
  80. void gpt_sampler_reset (struct gpt_sampler * gsmpl);
  81. struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
  // Print performance metrics. Either argument may be nullptr, in which case the printing
  // for that argument is skipped.
  void gpt_perf_print(
          const struct llama_context * ctx,
          const struct gpt_sampler   * gsmpl);
  84. // extended sampling implementation:
  85. //
  86. // - set logits
  87. // - apply the configured sampler chain
  88. // - check if the token fits the grammar (if any)
  89. // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
  90. //
  91. // if grammar_first is true, the grammar is applied before the samplers (slower)
  92. // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
  93. //
  94. llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
  95. // helpers
  96. // access the internal list of current candidate tokens
  97. llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
  98. // get the last accepted token
  99. llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
  100. // print the sampler chain into a string
  101. std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
  102. // get a string representation of the last accepted tokens
  103. std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);
  104. char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
  105. std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
  106. std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
  107. std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);