1
0

test-sampling.cpp 7.8 KB


#include "ggml.h"
#include "llama.h"

// Every check in this file is an assert(); keep them active even in release builds.
// This must stay above <cassert>.
#ifdef NDEBUG
#undef NDEBUG
#endif

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <numeric>
#include <vector>
  12. void dump(const llama_token_data_array * candidates) {
  13. for (size_t i = 0; i < candidates->size; i++) {
  14. printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
  15. }
  16. }
  17. #define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
  18. void test_top_k(const std::vector<float> & probs,
  19. const std::vector<float> & expected_probs,
  20. int k) {
  21. size_t n_vocab = probs.size();
  22. std::vector<llama_token_data> candidates;
  23. candidates.reserve(n_vocab);
  24. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  25. float logit = log(probs[token_id]);
  26. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  27. }
  28. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  29. llama_sample_softmax(nullptr, &candidates_p);
  30. DUMP(&candidates_p);
  31. llama_sample_top_k(nullptr, &candidates_p, k, 1);
  32. DUMP(&candidates_p);
  33. assert(candidates_p.size == expected_probs.size());
  34. for (size_t i = 0; i < candidates_p.size; i++) {
  35. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
  36. }
  37. }
  38. void test_top_p(const std::vector<float> & probs,
  39. const std::vector<float> & expected_probs,
  40. float p) {
  41. size_t n_vocab = probs.size();
  42. std::vector<llama_token_data> candidates;
  43. candidates.reserve(n_vocab);
  44. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  45. float logit = log(probs[token_id]);
  46. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  47. }
  48. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  49. llama_sample_softmax(nullptr, &candidates_p);
  50. DUMP(&candidates_p);
  51. llama_sample_top_p(nullptr, &candidates_p, p, 1);
  52. DUMP(&candidates_p);
  53. assert(candidates_p.size == expected_probs.size());
  54. for (size_t i = 0; i < candidates_p.size; i++) {
  55. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
  56. }
  57. }
  58. void test_tfs(const std::vector<float> & probs,
  59. const std::vector<float> & expected_probs,
  60. float z) {
  61. size_t n_vocab = probs.size();
  62. std::vector<llama_token_data> candidates;
  63. candidates.reserve(n_vocab);
  64. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  65. float logit = log(probs[token_id]);
  66. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  67. }
  68. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  69. DUMP(&candidates_p);
  70. llama_sample_tail_free(nullptr, &candidates_p, z, 1);
  71. DUMP(&candidates_p);
  72. assert(candidates_p.size == expected_probs.size());
  73. for (size_t i = 0; i < candidates_p.size; i++) {
  74. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
  75. }
  76. }
  77. void test_typical(const std::vector<float> & probs,
  78. const std::vector<float> & expected_probs,
  79. float p) {
  80. size_t n_vocab = probs.size();
  81. std::vector<llama_token_data> candidates;
  82. candidates.reserve(n_vocab);
  83. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  84. float logit = log(probs[token_id]);
  85. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  86. }
  87. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  88. DUMP(&candidates_p);
  89. llama_sample_typical(nullptr, &candidates_p, p, 1);
  90. DUMP(&candidates_p);
  91. assert(candidates_p.size == expected_probs.size());
  92. for (size_t i = 0; i < candidates_p.size; i++) {
  93. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
  94. }
  95. }
  96. void test_repetition_penalty(
  97. const std::vector<float> & probs,
  98. const std::vector<llama_token> & last_tokens,
  99. const std::vector<float> & expected_probs,
  100. float penalty) {
  101. assert(probs.size() == expected_probs.size());
  102. size_t n_vocab = probs.size();
  103. std::vector<llama_token_data> candidates;
  104. candidates.reserve(n_vocab);
  105. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  106. float logit = log(probs[token_id]);
  107. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  108. }
  109. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  110. llama_sample_softmax(nullptr, &candidates_p);
  111. DUMP(&candidates_p);
  112. llama_sample_repetition_penalty(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), penalty);
  113. llama_sample_softmax(nullptr, &candidates_p);
  114. DUMP(&candidates_p);
  115. assert(candidates_p.size == expected_probs.size());
  116. for (size_t i = 0; i < candidates_p.size; i++) {
  117. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
  118. }
  119. }
  120. void test_frequency_presence_penalty(
  121. const std::vector<float> & probs,
  122. const std::vector<llama_token> & last_tokens,
  123. const std::vector<float> & expected_probs,
  124. float alpha_frequency, float alpha_presence) {
  125. assert(probs.size() == expected_probs.size());
  126. size_t n_vocab = probs.size();
  127. std::vector<llama_token_data> candidates;
  128. candidates.reserve(n_vocab);
  129. for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
  130. float logit = log(probs[token_id]);
  131. candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
  132. }
  133. llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
  134. llama_sample_softmax(nullptr, &candidates_p);
  135. // DUMP(&candidates_p);
  136. llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
  137. llama_sample_softmax(nullptr, &candidates_p);
  138. // DUMP(&candidates_p);
  139. assert(candidates_p.size == expected_probs.size());
  140. for (size_t i = 0; i < candidates_p.size; i++) {
  141. assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
  142. }
  143. }
  144. int main(void) {
  145. ggml_time_init();
  146. test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
  147. test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
  148. test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
  149. test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
  150. test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f);
  151. test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
  152. test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
  153. test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
  154. test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
  155. test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
  156. test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
  157. test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f);
  158. test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
  159. test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
  160. test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 5.0f, 5.0f);
  161. test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f);
  162. test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f);
  163. printf("OK\n");
  164. return 0;
  165. }