// sampling.h
#pragma once

#include "llama.h"
#include "common.h"

#include <cstdint>
#include <string>
#include <vector>
// gpt_sampler extends llama_sampler with additional functionality:
//
// - grammar support
// - custom sampler logic based on the parameters
// - history of the last accepted tokens
// - performance metrics
//
// The goal is to have a common implementation of the sampling logic shared across the examples.
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
// complex (top-k, top-p, etc.).
//
// Another example is related to the grammar. In general, applying the grammar constraints to the full
// vocabulary can be very expensive. To improve performance, the grammar can be applied only to the
// sampled token in order to verify that it fits the grammar. Only if the token does not fit the grammar
// are the grammar constraints applied to the full vocabulary and the token resampled.
//
// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can
// be moved into the core llama library.
//
// For convenience, the gpt_sampler also maintains a container with the current candidate tokens.
// This can be used to access the probabilities of the rest of the non-sampled tokens.
//
// TODO: measure grammar performance
//
  30. struct gpt_sampler;
  31. // llama_sampler API overloads
  32. struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
  33. void gpt_sampler_free(struct gpt_sampler * gsmpl);
  34. // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
  35. void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
  36. void gpt_sampler_reset (struct gpt_sampler * gsmpl);
  37. struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
  38. // arguments can be nullptr to skip printing
  39. void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl);
  40. // extended sampling implementation:
  41. //
  42. // - set logits
  43. // - apply the configured sampler chain
  44. // - check if the token fits the grammar (if any)
  45. // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
  46. //
  47. // if grammar_first is true, the grammar is applied before the samplers (slower)
  48. // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
  49. //
  50. llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
  51. uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl);
  52. // helpers
  53. // access the internal list of current candidate tokens
  54. llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
  55. // get the last accepted token
  56. llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
  57. // print the sampler chain into a string
  58. std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
  59. // get a string representation of the last accepted tokens
  60. std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);
  61. char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
  62. std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
  63. std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
  64. std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);