// sampling.h
  #pragma once

  #include "llama.h"
  #include "common.h"

  #include <memory>
  #include <string>
  #include <vector>
  // common_sampler extends llama_sampler with additional functionality:
  //
  // - grammar support
  // - custom sampler logic based on the parameters
  // - history of the last accepted tokens
  // - performance metrics
  //
  // The goal is to have a common implementation of the sampling logic shared across the examples.
  // For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
  // complex (top-k, top-p, etc.).
  //
  // Another example is related to the grammar. In general, applying the grammar constraints to the full
  // vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
  // token in order to verify whether it fits the grammar. Only if the token doesn't fit the grammar are
  // the grammar constraints applied to the full vocabulary and the token resampled.
  //
  // The common_sampler also maintains a container with the last accepted tokens. In the future, this can
  // be moved into the core llama library.
  //
  // For convenience, the common_sampler also maintains a container with the current candidate tokens.
  // This can be used to access the probabilities of the rest of the non-sampled tokens.
  //
  // TODO: measure grammar performance
  //
  30. struct common_sampler;
  31. // llama_sampler API overloads
  32. struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
  33. void common_sampler_free(struct common_sampler * gsmpl);
  34. // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
  35. void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
  36. void common_sampler_reset (struct common_sampler * gsmpl);
  37. struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
  38. // arguments can be nullptr to skip printing
  39. void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
  40. struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
  41. // extended sampling implementation:
  42. //
  43. // - set logits
  44. // - apply the configured sampler chain
  45. // - check if the token fits the grammar (if any)
  46. // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
  47. //
  48. llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
  49. // generalized version of common_sampler_sample
  50. //
  51. // will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
  52. // if the sampler disagrees at some point, we stop and return the accepted tokens up to now
  53. //
  54. // common_sampler_sample_n(gsmpl, ctx, { idx }, {});
  55. //
  56. // is equivalent to
  57. //
  58. // common_sampler_sample(gsmpl, ctx, idx);
  59. // common_sampler_accept(gsmpl, token, true);
  60. //
  61. // requires: idxs.size() == draft.size() + 1
  62. //
  63. // returns at least 1 token, up to idxs.size()
  64. //
  65. std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
  66. // assume idxs == [ 0, 1, 2, ..., draft.size() ]
  67. std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
  // Seed associated with `gsmpl`.
  // NOTE(review): presumably the RNG seed used by the sampling chain — confirm.
  uint32_t common_sampler_get_seed(
          const struct common_sampler * gsmpl);
  69. // helpers
  70. // access the internal list of current candidate tokens
  71. // if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
  72. // the .sorted flag of the result indicates whether the returned candidates are sorted
  73. llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
  74. // get the last accepted token
  75. llama_token common_sampler_last(const struct common_sampler * gsmpl);
  76. // print the sampler chain into a string
  77. std::string common_sampler_print(const struct common_sampler * gsmpl);
  78. // get a string representation of the last accepted tokens
  79. std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
  80. char common_sampler_type_to_chr(enum common_sampler_type cnstr);
  81. std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
  82. std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
  83. std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
  84. llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
  85. const char * grammar_kind, const char * grammar_data);
  86. struct common_sampler_deleter {
  87. void operator()(common_sampler * s) { common_sampler_free(s); }
  88. };
  89. typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;