k_quants.h 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. #pragma once
  2. #include "ggml.h"
  3. #include <stdint.h>
  4. #include <assert.h>
  5. #include <stddef.h>
  // Super-block size
  //
  // QK_K is the number of weights stored in one k-quant super-block.
  // K_SCALE_SIZE is the byte length of the packed scales[] array used by the
  // non-GGML_QKK_64 layouts of block_q4_K and block_q5_K in this header.
  //
  // Define GGML_QKK_64 at build time to select the 64-weight super-block;
  // otherwise the 256-weight super-block is used.
  #ifdef GGML_QKK_64
  #define QK_K 64
  #define K_SCALE_SIZE 4
  #else
  #define QK_K 256
  #define K_SCALE_SIZE 12
  #endif
  14. //
  15. // Super-block quantization structures
  16. //
  17. // 2-bit quantization
  18. // weight is represented as x = a * q + b
  19. // 16 blocks of 16 elemenets each
  20. // Effectively 2.5625 bits per weight
  21. typedef struct {
  22. uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
  23. uint8_t qs[QK_K/4]; // quants
  24. ggml_fp16_t d; // super-block scale for quantized scales
  25. ggml_fp16_t dmin; // super-block scale for quantized mins
  26. } block_q2_K;
  27. static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
  28. // 3-bit quantization
  29. // weight is represented as x = a * q
  30. // 16 blocks of 16 elemenets each
  31. // Effectively 3.4375 bits per weight
  32. #ifdef GGML_QKK_64
  33. typedef struct {
  34. uint8_t hmask[QK_K/8]; // quants - high bit
  35. uint8_t qs[QK_K/4]; // quants - low 2 bits
  36. uint8_t scales[2];
  37. ggml_fp16_t d; // super-block scale
  38. } block_q3_K;
  39. static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
  40. #else
  41. typedef struct {
  42. uint8_t hmask[QK_K/8]; // quants - high bit
  43. uint8_t qs[QK_K/4]; // quants - low 2 bits
  44. uint8_t scales[12]; // scales, quantized with 6 bits
  45. ggml_fp16_t d; // super-block scale
  46. } block_q3_K;
  47. static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
  48. #endif
  49. // 4-bit quantization
  50. // 16 blocks of 32 elements each
  51. // weight is represented as x = a * q + b
  52. // Effectively 4.5 bits per weight
  53. #ifdef GGML_QKK_64
  54. typedef struct {
  55. ggml_fp16_t d[2]; // super-block scales/mins
  56. uint8_t scales[2]; // 4-bit block scales/mins
  57. uint8_t qs[QK_K/2]; // 4--bit quants
  58. } block_q4_K;
  59. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
  60. #else
  61. typedef struct {
  62. ggml_fp16_t d; // super-block scale for quantized scales
  63. ggml_fp16_t dmin; // super-block scale for quantized mins
  64. uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  65. uint8_t qs[QK_K/2]; // 4--bit quants
  66. } block_q4_K;
  67. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
  68. #endif
  69. // 5-bit quantization
  70. // 16 blocks of 32 elements each
  71. // weight is represented as x = a * q + b
  72. // Effectively 5.5 bits per weight
  73. #ifdef GGML_QKK_64
  74. typedef struct {
  75. ggml_fp16_t d; // super-block scale
  76. int8_t scales[QK_K/16]; // 8-bit block scales
  77. uint8_t qh[QK_K/8]; // quants, high bit
  78. uint8_t qs[QK_K/2]; // quants, low 4 bits
  79. } block_q5_K;
  80. static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
  81. #else
  82. typedef struct {
  83. ggml_fp16_t d; // super-block scale for quantized scales
  84. ggml_fp16_t dmin; // super-block scale for quantized mins
  85. uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  86. uint8_t qh[QK_K/8]; // quants, high bit
  87. uint8_t qs[QK_K/2]; // quants, low 4 bits
  88. } block_q5_K;
  89. static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
  90. #endif
  91. // 6-bit quantization
  92. // weight is represented as x = a * q
  93. // 16 blocks of 16 elemenets each
  94. // Effectively 6.5625 bits per weight
  95. typedef struct {
  96. uint8_t ql[QK_K/2]; // quants, lower 4 bits
  97. uint8_t qh[QK_K/4]; // quants, upper 2 bits
  98. int8_t scales[QK_K/16]; // scales, quantized with 8 bits
  99. ggml_fp16_t d; // super-block scale
  100. } block_q6_K;
  101. static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
  102. // This is only used for intermediate quantization and dot products
  103. typedef struct {
  104. float d; // delta
  105. int8_t qs[QK_K]; // quants
  106. int16_t bsums[QK_K/16]; // sum of quants in groups of 16
  107. } block_q8_K;
  108. static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
  109. // Quantization
  110. void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
  111. void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
  112. void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
  113. void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
  114. void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
  115. void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
  // Quantization (type-erased variants)
  //
  // Same parameter shape as the *_reference quantizers, but the destination is
  // an untyped `void *`. The two buffers are declared restrict, so they must
  // not overlap.
  // NOTE(review): `y` presumably points at an array of the matching block_qN_K
  // type and `k` is presumably the number of floats in `x` - confirm against
  // the definitions, which are not in this header.
  void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
  122. // Dequantization
  123. void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
  124. void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
  125. void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
  126. void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
  127. void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
  128. void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
  // Dot product
  //
  // Each function combines two untyped inputs `vx` and `vy` and reports a
  // float through `s`. All three pointers are declared restrict, so they must
  // not alias one another.
  // NOTE(review): going by the names, `vx` presumably holds block_qN_K data,
  // `vy` block_q8_K data, and `n` the number of weights - confirm against the
  // definitions, which are not in this header.
  void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  // Quantization with histogram collection
  //
  // Each function reads floats from `src`, writes to the untyped buffer `dst`,
  // and returns a size_t.
  // NOTE(review): the return value is presumably the number of bytes written
  // to `dst`, `n` the total element count, `k` the row length, and `hist` a
  // caller-provided array of bucket counters - none of this is visible in this
  // header; confirm against the definitions.
  size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
  size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
  size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
  size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
  size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);