test-sampling.cpp 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
#include "ggml.h"
#include "llama.h"

#ifdef NDEBUG
#undef NDEBUG
#endif

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <numeric>
#include <vector>
  12. static void dump(const llama_token_data_array * candidates) {
  13. for (size_t i = 0; i < candidates->size; i++) {
  14. printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
  15. }
  16. }
  // DUMP(cands): prints the call site as "file:line (function)", dumps the candidate array via dump(), then a "-" separator line.
// The parameter was renamed from `__candidates`: identifiers containing a double underscore are reserved for the implementation.
#define DUMP(cands) do { std::printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((cands)); std::printf("-\n"); } while(0)
  18. static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
  19. size_t n_vocab = probs.size();
  20. std::vector<llama_token_data> candidates;
  21. candidates.reserve(n_vocab);
  22. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  23. float logit = log(probs[token_id]);
  24. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  25. }
  26. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  27. llama_sample_softmax(nullptr, &candidates_p);
  28. DUMP(&candidates_p);
  29. llama_sample_top_k(nullptr, &candidates_p, k, 1);
  30. DUMP(&candidates_p);
  31. assert(candidates_p.size == expected_probs.size());
  32. for (size_t i = 0; i < candidates_p.size; i++) {
  33. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
  34. }
  35. }
  36. static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
  37. size_t n_vocab = probs.size();
  38. std::vector<llama_token_data> candidates;
  39. candidates.reserve(n_vocab);
  40. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  41. float logit = log(probs[token_id]);
  42. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  43. }
  44. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  45. llama_sample_softmax(nullptr, &candidates_p);
  46. DUMP(&candidates_p);
  47. llama_sample_top_p(nullptr, &candidates_p, p, 1);
  48. DUMP(&candidates_p);
  49. assert(candidates_p.size == expected_probs.size());
  50. for (size_t i = 0; i < candidates_p.size; i++) {
  51. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
  52. }
  53. }
  54. static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
  55. size_t n_vocab = probs.size();
  56. std::vector<llama_token_data> candidates;
  57. candidates.reserve(n_vocab);
  58. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  59. float logit = log(probs[token_id]);
  60. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  61. }
  62. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  63. DUMP(&candidates_p);
  64. llama_sample_tail_free(nullptr, &candidates_p, z, 1);
  65. DUMP(&candidates_p);
  66. assert(candidates_p.size == expected_probs.size());
  67. for (size_t i = 0; i < candidates_p.size; i++) {
  68. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
  69. }
  70. }
  71. static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
  72. size_t n_vocab = probs.size();
  73. std::vector<llama_token_data> candidates;
  74. candidates.reserve(n_vocab);
  75. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  76. float logit = log(probs[token_id]);
  77. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  78. }
  79. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  80. DUMP(&candidates_p);
  81. llama_sample_typical(nullptr, &candidates_p, p, 1);
  82. DUMP(&candidates_p);
  83. assert(candidates_p.size == expected_probs.size());
  84. for (size_t i = 0; i < candidates_p.size; i++) {
  85. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
  86. }
  87. }
  88. static void test_repetition_penalty(
  89. const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
  90. const std::vector<float> & expected_probs, float penalty
  91. ) {
  92. assert(probs.size() == expected_probs.size());
  93. size_t n_vocab = probs.size();
  94. std::vector<llama_token_data> candidates;
  95. candidates.reserve(n_vocab);
  96. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  97. float logit = log(probs[token_id]);
  98. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  99. }
  100. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  101. llama_sample_softmax(nullptr, &candidates_p);
  102. DUMP(&candidates_p);
  103. llama_sample_repetition_penalty(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), penalty);
  104. llama_sample_softmax(nullptr, &candidates_p);
  105. DUMP(&candidates_p);
  106. assert(candidates_p.size == expected_probs.size());
  107. for (size_t i = 0; i < candidates_p.size; i++) {
  108. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
  109. }
  110. }
  111. static void test_frequency_presence_penalty(
  112. const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
  113. const std::vector<float> & expected_probs, float alpha_frequency, float alpha_presence
  114. ) {
  115. assert(probs.size() == expected_probs.size());
  116. size_t n_vocab = probs.size();
  117. std::vector<llama_token_data> candidates;
  118. candidates.reserve(n_vocab);
  119. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  120. float logit = log(probs[token_id]);
  121. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  122. }
  123. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  124. llama_sample_softmax(nullptr, &candidates_p);
  125. // DUMP(&candidates_p);
  126. llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
  127. llama_sample_softmax(nullptr, &candidates_p);
  128. // DUMP(&candidates_p);
  129. assert(candidates_p.size == expected_probs.size());
  130. for (size_t i = 0; i < candidates_p.size; i++) {
  131. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
  132. }
  133. }
  134. int main(void) {
  135. ggml_time_init();
  136. test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
  137. test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
  138. test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
  139. test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
  140. test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
  141. test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
  142. test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
  143. test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
  144. test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
  145. test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
  146. test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
  147. test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f);
  148. test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
  149. test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
  150. test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 5.0f, 5.0f);
  151. test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f);
  152. test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f);
  153. printf("OK\n");
  154. return 0;
  155. }