test-sampling.cpp

#include "ggml.h"
#include "llama.h"

#ifdef NDEBUG
#undef NDEBUG
#endif

#include <cmath>
#include <cstdio>
#include <numeric>
#include <cassert>
#include <vector>
#include <algorithm>
// print the id, probability and logit of every candidate in the array
static void dump(const llama_token_data_array * candidates) {
    for (size_t i = 0; i < candidates->size; i++) {
        printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
    }
}

#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
// build candidates from the given probabilities (logit = log(p)), apply top-k
// sampling and check the surviving probabilities against the expected values
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    DUMP(&candidates_p);
    llama_sample_top_k(nullptr, &candidates_p, k, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
    }
}
// same setup as test_top_k, but exercises top-p (nucleus) sampling
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    DUMP(&candidates_p);
    llama_sample_top_p(nullptr, &candidates_p, p, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
    }
}
// same setup, but exercises tail-free sampling with parameter z
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sample_tail_free(nullptr, &candidates_p, z, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
    }
}
// same setup, but exercises locally typical sampling with parameter p
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sample_typical(nullptr, &candidates_p, p, 1);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
    }
}
// applies repeat/frequency/presence penalties for the given history of
// last_tokens, re-normalizes, and checks the resulting probabilities
static void test_repetition_penalties(
    const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
    const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
) {
    GGML_ASSERT(probs.size() == expected_probs.size());

    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    DUMP(&candidates_p);
    llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
    llama_sample_softmax(nullptr, &candidates_p);
    DUMP(&candidates_p);

    GGML_ASSERT(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
    }
}
int main(void) {
    ggml_time_init();

    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);

    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
    test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);

    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
    test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);

    test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
    test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);

    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0},             {0.25f, 0.25f, 0.25f, 0.25f, 0},   50.0f, 0.0f, 0.0f);
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0.5f, 0.5f, 0, 0, 0},             50.0f, 0.0f, 0.0f);
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0},             50.0f, 0.0f, 0.0f);

    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0},             {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2},       {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
    test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);

    printf("OK\n");

    return 0;
}
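
For reference, the helpers above construct each candidate's logit as log(p) and rely on a softmax to turn those logits back into probabilities before checking them. Below is a minimal standalone sketch of that round trip in plain C++. It is a separate illustrative program, not part of the test above and not the library's softmax implementation:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // The test helpers store logit = log(p). Softmax of those logits,
    // exp(l_i - max) / sum, simply re-normalizes the original probabilities,
    // so a distribution that already sums to 1 is recovered unchanged.
    std::vector<float> probs = {0.1f, 0.2f, 0.3f, 0.4f};

    std::vector<float> logits;
    for (float p : probs) {
        logits.push_back(std::log(p));
    }

    // subtract the max logit for numerical stability before exponentiating
    float max_logit = *std::max_element(logits.begin(), logits.end());

    std::vector<float> out;
    float sum = 0.0f;
    for (float l : logits) {
        out.push_back(std::exp(l - max_logit));
        sum += out.back();
    }
    for (float & p : out) {
        p /= sum;
    }

    for (size_t i = 0; i < out.size(); i++) {
        printf("p[%zu] = %f (expected %f)\n", i, out[i], probs[i]);
    }
    return 0;
}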