test-sampling.cpp

#include "ggml.h"
#include "llama.h"

#ifdef NDEBUG
#undef NDEBUG
#endif

#include <cmath>    // log, fabs
#include <cstdio>   // printf
#include <numeric>
#include <cassert>
#include <iostream>
#include <vector>
#include <algorithm>
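
// debug helpers: dump() prints each candidate's token id, probability and logit;
// DUMP() prefixes that output with the current file, line and function.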
void dump(const llama_token_data_array * candidates) {
    for (size_t i = 0; i < candidates->size; i++) {
        printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
    }
}

#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
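
// each test_* helper builds a candidate list from `probs` (logit = log(p)),
// runs the sampler under test, then checks the surviving probabilities
// against `expected_probs` within a small tolerance.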
void test_top_k(const std::vector<float> & probs,
                const std::vector<float> & expected_probs,
                int k) {
    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    DUMP(&candidates_p);
    llama_sample_top_k(nullptr, &candidates_p, k, 1);
    DUMP(&candidates_p);

    assert(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
    }
}
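
// top-p (nucleus) filtering via llama_sample_top_p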
void test_top_p(const std::vector<float> & probs,
                const std::vector<float> & expected_probs,
                float p) {
    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    DUMP(&candidates_p);
    llama_sample_top_p(nullptr, &candidates_p, p, 1);
    DUMP(&candidates_p);

    assert(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
    }
}
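
// tail-free sampling via llama_sample_tail_free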
void test_tfs(const std::vector<float> & probs,
              const std::vector<float> & expected_probs,
              float z) {
    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sample_tail_free(nullptr, &candidates_p, z, 1);
    DUMP(&candidates_p);

    assert(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
    }
}
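
// locally typical sampling via llama_sample_typical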
void test_typical(const std::vector<float> & probs,
                  const std::vector<float> & expected_probs,
                  float p) {
    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    DUMP(&candidates_p);
    llama_sample_typical(nullptr, &candidates_p, p, 1);
    DUMP(&candidates_p);

    assert(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
    }
}
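
// repetition penalty: tokens listed in `last_tokens` are penalized before re-normalizing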
void test_repetition_penalty(
    const std::vector<float> & probs,
    const std::vector<llama_token> & last_tokens,
    const std::vector<float> & expected_probs,
    float penalty) {
    assert(probs.size() == expected_probs.size());

    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    DUMP(&candidates_p);
    llama_sample_repetition_penalty(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), penalty);
    llama_sample_softmax(nullptr, &candidates_p);
    DUMP(&candidates_p);

    assert(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
    }
}
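
// frequency/presence penalties applied to tokens seen in `last_tokens`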
void test_frequency_presence_penalty(
    const std::vector<float> & probs,
    const std::vector<llama_token> & last_tokens,
    const std::vector<float> & expected_probs,
    float alpha_frequency, float alpha_presence) {
    assert(probs.size() == expected_probs.size());

    size_t n_vocab = probs.size();
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
        float logit = log(probs[token_id]);
        candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
    llama_sample_softmax(nullptr, &candidates_p);
    // DUMP(&candidates_p);
    llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
    llama_sample_softmax(nullptr, &candidates_p);
    // DUMP(&candidates_p);

    assert(candidates_p.size == expected_probs.size());
    for (size_t i = 0; i < candidates_p.size; i++) {
        assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
    }
}
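
// exercise each sampler against small hand-constructed distributions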
int main(void) {
    ggml_time_init();

    test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1);
    test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3);

    test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4}, 0);
    test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3}, 0.7);
    test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2, 0.1}, 1);

    test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3}, 0.25);
    test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.75);
    test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.99);

    test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5);
    test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5);

    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.25, 0.25, 0.25, 0.25, 0}, 50.0);
    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.5, 0.5, 0, 0, 0}, 50.0);
    test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.5, 0.5, 0, 0, 0}, 50.0);

    test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0);
    test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0);
    test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.499977, 0.499977, 0.000023, 0.000023, 0.000000}, 5.0, 5.0);

    printf("OK\n");
}