test-quantize-fns.cpp 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. // Unit tests for quantization specific functions - quantize, dequantize and dot product
  2. #include "ggml.h"
  3. #undef NDEBUG
  4. #include <assert.h>
  5. #include <math.h>
  6. #include <stdio.h>
  7. #include <string>
  8. #include <vector>
  9. const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001;
  10. const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002;
  11. const float MAX_DOT_PRODUCT_ERROR = 0.02;
  12. const char* RESULT_STR[] = {"ok", "FAILED"};
  13. // Generate synthetic data
  14. void generate_data(float offset, size_t n, float * dst) {
  15. for (size_t i = 0; i < n; i++) {
  16. dst[i] = 0.1 + 2*cosf(i + offset);
  17. }
  18. }
  19. // Calculate RMSE between two float arrays
  20. float array_rmse(const float * a1, const float * a2, size_t n) {
  21. double sum = 0;
  22. for (size_t i = 0; i < n; i++) {
  23. double diff = a1[i] - a2[i];
  24. sum += diff * diff;
  25. }
  26. return sqrtf(sum) / n;
  27. }
  28. // Total quantization error on test data
  29. float total_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
  30. std::vector<uint8_t> tmp_q(test_size);
  31. std::vector<float> tmp_out(test_size);
  32. qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
  33. qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
  34. return array_rmse(test_data, tmp_out.data(), test_size);
  35. }
  36. // Total quantization error on test data
  37. float reference_quantization_error(quantize_fns_t & qfns, size_t test_size, const float * test_data) {
  38. std::vector<uint8_t> tmp_q(test_size);
  39. std::vector<float> tmp_out(test_size);
  40. std::vector<float> tmp_out_ref(test_size);
  41. qfns.quantize_row_q(test_data, tmp_q.data(), test_size);
  42. qfns.dequantize_row_q(tmp_q.data(), tmp_out.data(), test_size);
  43. qfns.quantize_row_q_reference(test_data, tmp_q.data(), test_size);
  44. qfns.dequantize_row_q(tmp_q.data(), tmp_out_ref.data(), test_size);
  45. return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
  46. }
  47. float dot_product(const float * a1, const float * a2, size_t test_size) {
  48. double sum = 0;
  49. for (size_t i = 0; i < test_size; i++) {
  50. sum += a1[i] * a2[i];
  51. }
  52. return sum;
  53. }
  54. // Total dot product error
  55. float dot_product_error(quantize_fns_t & qfns, size_t test_size, const float * test_data1, const float *test_data2) {
  56. std::vector<uint8_t> tmp_q1(test_size);
  57. std::vector<uint8_t> tmp_q2(test_size*2);
  58. qfns.quantize_row_q(test_data1, tmp_q1.data(), test_size);
  59. qfns.quantize_row_q_dot(test_data2, tmp_q2.data(), test_size);
  60. float result = INFINITY;
  61. qfns.vec_dot_q(test_size, &result, tmp_q1.data(), tmp_q2.data());
  62. const float dot_ref = dot_product(test_data1, test_data2, test_size);
  63. return fabsf(result - dot_ref) / test_size;
  64. }
  65. int main(int argc, char * argv[]) {
  66. bool verbose = false;
  67. const size_t test_size = 32 * 128;
  68. std::string arg;
  69. for (int i = 1; i < argc; i++) {
  70. arg = argv[i];
  71. if (arg == "-v") {
  72. verbose = true;
  73. } else {
  74. fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
  75. return 1;
  76. }
  77. }
  78. std::vector<float> test_data(test_size);
  79. std::vector<float> test_data2(test_size);
  80. generate_data(0.0, test_data.size(), test_data.data());
  81. generate_data(1.0, test_data2.size(), test_data2.data());
  82. // Initialize GGML, ensures float conversion tables are initialized
  83. struct ggml_init_params ggml_params = {
  84. /* .mem_size = */ 1*1024,
  85. /* .mem_buffer = */ NULL,
  86. /* .no_alloc = */ true,
  87. };
  88. struct ggml_context * ctx = ggml_init(ggml_params);
  89. int num_failed = 0;
  90. bool failed = false;
  91. for (int i = 0; i < GGML_TYPE_COUNT; i++) {
  92. ggml_type type = (ggml_type) i;
  93. quantize_fns_t qfns = ggml_internal_get_quantize_fn(i);
  94. if (qfns.quantize_row_q) {
  95. const float total_error = total_quantization_error(qfns, test_size, test_data.data());
  96. failed = !(total_error < MAX_QUANTIZATION_TOTAL_ERROR);
  97. num_failed += failed;
  98. if (failed || verbose) {
  99. printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
  100. }
  101. const float reference_error = reference_quantization_error(qfns, test_size, test_data.data());
  102. failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
  103. num_failed += failed;
  104. if (failed || verbose) {
  105. printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
  106. }
  107. const float vec_dot_error = dot_product_error(qfns, test_size, test_data.data(), test_data2.data());
  108. failed = !(vec_dot_error < MAX_DOT_PRODUCT_ERROR);
  109. num_failed += failed;
  110. if (failed || verbose) {
  111. printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
  112. }
  113. }
  114. }
  115. if (num_failed || verbose) {
  116. printf("%d tests failed\n", num_failed);
  117. }
  118. ggml_free(ctx);
  119. return num_failed > 0;
  120. }