test-quantize-fns.cpp 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. // Unit tests for quantization specific functions - quantize, dequantize and dot product
  2. #include "ggml.h"
  3. #include "ggml-cpu.h"
  4. #undef NDEBUG
  5. #include <assert.h>
  6. #include <math.h>
  7. #include <stdio.h>
  8. #include <string>
  9. #include <vector>
  10. #if defined(_MSC_VER)
  11. #pragma warning(disable: 4244 4267) // possible loss of data
  12. #endif
  // Pass/fail thresholds. main() treats a check as passed only when the
  // measured error is strictly below the corresponding threshold.

  // Max difference between the CPU backend quantizer and the reference quantizer.
  static constexpr float MAX_QUANTIZATION_REFERENCE_ERROR{0.0001f};

  // Max quantize -> dequantize round-trip error: default, then per-format overrides.
  static constexpr float MAX_QUANTIZATION_TOTAL_ERROR{0.002f};
  static constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY{0.01f};
  static constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS{0.0075f};
  static constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS{0.0040f};
  static constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS{0.0050f};

  // Max per-element dot product error: default, then low-bit and ternary overrides.
  static constexpr float MAX_DOT_PRODUCT_ERROR{0.02f};
  static constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT{0.04f};
  static constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY{0.15f};

  // Printable check outcome, indexed by a `failed` bool (false -> "ok", true -> "FAILED").
  static const char * const RESULT_STR[] = {"ok", "FAILED"};
  // Fill dst[0..n) with a deterministic, smooth synthetic signal:
  //   dst[k] = 0.1 + 2*cos(k + offset)
  // so every value lies in roughly [-1.9, 2.1]. Different offsets give
  // phase-shifted (and therefore distinct) test vectors.
  static void generate_data(float offset, size_t n, float * dst) {
      for (size_t k = 0; k != n; ++k) {
          const float phase = k + offset;  // index is converted to float before cosf
          dst[k] = 0.1 + 2*cosf(phase);
      }
  }
  // Error metric between two float arrays of length n.
  //
  // NOTE: despite the name, this is NOT the textbook RMSE sqrt(sum(d^2) / n).
  // It returns sqrt(sum(d^2)) / n, which is the RMSE divided by sqrt(n), with
  // d = a1[i] - a2[i]. The MAX_QUANTIZATION_* thresholds in this file are
  // compared against values with this scaling, so switching to the textbook
  // formula would also require rescaling those thresholds.
  //
  // Each difference is formed in float and squared/accumulated in double. The
  // sum is narrowed to float before the square root. n == 0 yields NaN (0/0).
  static float array_rmse(const float * a1, const float * a2, size_t n) {
      double sq_sum = 0.0;
      for (size_t k = 0; k < n; ++k) {
          const double d = a1[k] - a2[k];
          sq_sum += d * d;
      }
      const float l2_norm = sqrtf(static_cast<float>(sq_sum));
      return l2_norm / static_cast<float>(n);
  }
  38. // Total quantization error on test data
  39. static float total_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
  40. std::vector<uint8_t> tmp_q(2*test_size);
  41. std::vector<float> tmp_out(test_size);
  42. qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
  43. qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
  44. return array_rmse(test_data, tmp_out.data(), test_size);
  45. }
  46. // Total quantization error on test data
  47. static float reference_quantization_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data) {
  48. std::vector<uint8_t> tmp_q(2*test_size);
  49. std::vector<float> tmp_out(test_size);
  50. std::vector<float> tmp_out_ref(test_size);
  51. // FIXME: why is done twice?
  52. qfns_cpu->from_float(test_data, tmp_q.data(), test_size);
  53. qfns->to_float(tmp_q.data(), tmp_out.data(), test_size);
  54. qfns->from_float_ref(test_data, tmp_q.data(), test_size);
  55. qfns->to_float(tmp_q.data(), tmp_out_ref.data(), test_size);
  56. return array_rmse(tmp_out.data(), tmp_out_ref.data(), test_size);
  57. }
  // Reference (unquantized) dot product of two float arrays of length test_size.
  // Each product is formed in float and accumulated in double. The sum is
  // narrowed to float on return.
  static float dot_product(const float * a1, const float * a2, size_t test_size) {
      double acc = 0.0;
      const float * const a1_end = a1 + test_size;
      for (const float * p = a1, * q = a2; p != a1_end; ++p, ++q) {
          acc += *p * *q;
      }
      return static_cast<float>(acc);
  }
  65. // Total dot product error
  66. static float dot_product_error(const ggml_type_traits * qfns, const ggml_type_traits_cpu * qfns_cpu, size_t test_size, const float * test_data1, const float * test_data2) {
  67. GGML_UNUSED(qfns);
  68. std::vector<uint8_t> tmp_q1(2*test_size);
  69. std::vector<uint8_t> tmp_q2(2*test_size);
  70. const auto * vdot = ggml_get_type_traits_cpu(qfns_cpu->vec_dot_type);
  71. qfns_cpu->from_float(test_data1, tmp_q1.data(), test_size);
  72. vdot->from_float(test_data2, tmp_q2.data(), test_size);
  73. float result = INFINITY;
  74. qfns_cpu->vec_dot(test_size, &result, 0, tmp_q1.data(), 0, tmp_q2.data(), 0, 1);
  75. const float dot_ref = dot_product(test_data1, test_data2, test_size);
  76. return fabsf(result - dot_ref) / test_size;
  77. }
  78. int main(int argc, char * argv[]) {
  79. bool verbose = false;
  80. const size_t test_size = 32 * 128;
  81. std::string arg;
  82. for (int i = 1; i < argc; i++) {
  83. arg = argv[i];
  84. if (arg == "-v") {
  85. verbose = true;
  86. } else {
  87. fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
  88. return 1;
  89. }
  90. }
  91. std::vector<float> test_data(test_size);
  92. std::vector<float> test_data2(test_size);
  93. generate_data(0.0, test_data.size(), test_data.data());
  94. generate_data(1.0, test_data2.size(), test_data2.data());
  95. // Initialize GGML, ensures float conversion tables are initialized
  96. struct ggml_init_params ggml_params = {
  97. /* .mem_size = */ 1*1024,
  98. /* .mem_buffer = */ NULL,
  99. /* .no_alloc = */ true,
  100. };
  101. struct ggml_context * ctx = ggml_init(ggml_params);
  102. int num_failed = 0;
  103. bool failed = false;
  104. for (int i = 0; i < GGML_TYPE_COUNT; i++) {
  105. ggml_type type = (ggml_type) i;
  106. const auto * qfns = ggml_get_type_traits(type);
  107. const auto * qfns_cpu = ggml_get_type_traits_cpu(type);
  108. // deprecated - skip
  109. if (qfns->blck_size == 0) {
  110. continue;
  111. }
  112. const ggml_type ei = (ggml_type)i;
  113. printf("Testing %s\n", ggml_type_name((ggml_type) i));
  114. ggml_quantize_init(ei);
  115. if (qfns_cpu->from_float && qfns->to_float) {
  116. const float total_error = total_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
  117. const float max_quantization_error =
  118. type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
  119. type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
  120. type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
  121. type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
  122. type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
  123. type == GGML_TYPE_IQ3_S ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
  124. type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR;
  125. failed = !(total_error < max_quantization_error);
  126. num_failed += failed;
  127. if (failed || verbose) {
  128. printf("%5s absolute quantization error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], total_error);
  129. }
  130. const float reference_error = reference_quantization_error(qfns, qfns_cpu, test_size, test_data.data());
  131. failed = !(reference_error < MAX_QUANTIZATION_REFERENCE_ERROR);
  132. num_failed += failed;
  133. if (failed || verbose) {
  134. printf("%5s reference implementation error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], reference_error);
  135. }
  136. const float vec_dot_error = dot_product_error(qfns, qfns_cpu, test_size, test_data.data(), test_data2.data());
  137. const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
  138. type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
  139. ? MAX_DOT_PRODUCT_ERROR_LOWBIT
  140. : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
  141. ? MAX_DOT_PRODUCT_ERROR_TERNARY
  142. : MAX_DOT_PRODUCT_ERROR;
  143. failed = !(vec_dot_error < max_allowed_error);
  144. num_failed += failed;
  145. if (failed || verbose) {
  146. printf("%5s dot product error: %s (%f)\n", ggml_type_name(type), RESULT_STR[failed], vec_dot_error);
  147. }
  148. }
  149. }
  150. if (num_failed || verbose) {
  151. printf("%d tests failed\n", num_failed);
  152. }
  153. ggml_free(ctx);
  154. return num_failed > 0;
  155. }