ggml-quants.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  #pragma once

  #include "ggml-impl.h"

  // GGML internal header

  #include <assert.h> // static_assert
  #include <stddef.h>
  #include <stdint.h>
  6. #define QK4_0 32
  7. typedef struct {
  8. ggml_fp16_t d; // delta
  9. uint8_t qs[QK4_0 / 2]; // nibbles / quants
  10. } block_q4_0;
  11. static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
  12. #define QK4_1 32
  13. typedef struct {
  14. ggml_fp16_t d; // delta
  15. ggml_fp16_t m; // min
  16. uint8_t qs[QK4_1 / 2]; // nibbles / quants
  17. } block_q4_1;
  18. static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding");
  19. #define QK5_0 32
  20. typedef struct {
  21. ggml_fp16_t d; // delta
  22. uint8_t qh[4]; // 5-th bit of quants
  23. uint8_t qs[QK5_0 / 2]; // nibbles / quants
  24. } block_q5_0;
  25. static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
  26. #define QK5_1 32
  27. typedef struct {
  28. ggml_fp16_t d; // delta
  29. ggml_fp16_t m; // min
  30. uint8_t qh[4]; // 5-th bit of quants
  31. uint8_t qs[QK5_1 / 2]; // nibbles / quants
  32. } block_q5_1;
  33. static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
  34. #define QK8_0 32
  35. typedef struct {
  36. ggml_fp16_t d; // delta
  37. int8_t qs[QK8_0]; // quants
  38. } block_q8_0;
  39. static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
  // Number of quants held by one q8_1 block.
  #define QK8_1 32

  // q8_1 block: QK8_1 signed 8-bit quants with a single-precision delta and
  // the precomputed value s = d * sum(qs[i]).
  typedef struct {
      float  d;         // delta
      float  s;         // d * sum(qs[i])
      int8_t qs[QK8_1]; // quants
  } block_q8_1;
  // The size must equal the sum of the fields, i.e. no padding may be inserted.
  static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");
  //
  // Super-block quantization structures
  //

  // Super-block size.
  // QK_K is the number of weights per super-block; K_SCALE_SIZE is the byte
  // length of the packed scales array in the default block_q4_K / block_q5_K
  // layouts. Defining GGML_QKK_64 at build time selects the 64-weight
  // configuration; otherwise 256 is used.
  #ifdef GGML_QKK_64
  #define QK_K 64
  #define K_SCALE_SIZE 4
  #else
  #define QK_K 256
  #define K_SCALE_SIZE 12
  #endif
  58. // 2-bit quantization
  59. // weight is represented as x = a * q + b
  60. // 16 blocks of 16 elements each
  61. // Effectively 2.625 bits per weight
  62. typedef struct {
  63. uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
  64. uint8_t qs[QK_K/4]; // quants
  65. ggml_fp16_t d; // super-block scale for quantized scales
  66. ggml_fp16_t dmin; // super-block scale for quantized mins
  67. } block_q2_K;
  68. static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
  69. // 3-bit quantization
  70. // weight is represented as x = a * q
  71. // 16 blocks of 16 elements each
  72. // Effectively 3.4375 bits per weight
  73. #ifdef GGML_QKK_64
  74. typedef struct {
  75. uint8_t hmask[QK_K/8]; // quants - high bit
  76. uint8_t qs[QK_K/4]; // quants - low 2 bits
  77. uint8_t scales[2];
  78. ggml_fp16_t d; // super-block scale
  79. } block_q3_K;
  80. static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
  81. #else
  82. typedef struct {
  83. uint8_t hmask[QK_K/8]; // quants - high bit
  84. uint8_t qs[QK_K/4]; // quants - low 2 bits
  85. uint8_t scales[12]; // scales, quantized with 6 bits
  86. ggml_fp16_t d; // super-block scale
  87. } block_q3_K;
  88. static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
  89. #endif
  90. // 4-bit quantization
  91. // 8 blocks of 32 elements each
  92. // weight is represented as x = a * q + b
  93. // Effectively 4.5 bits per weight
  94. #ifdef GGML_QKK_64
  95. typedef struct {
  96. ggml_fp16_t d[2]; // super-block scales/mins
  97. uint8_t scales[2]; // 4-bit block scales/mins
  98. uint8_t qs[QK_K/2]; // 4--bit quants
  99. } block_q4_K;
  100. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
  101. #else
  102. typedef struct {
  103. ggml_fp16_t d; // super-block scale for quantized scales
  104. ggml_fp16_t dmin; // super-block scale for quantized mins
  105. uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  106. uint8_t qs[QK_K/2]; // 4--bit quants
  107. } block_q4_K;
  108. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
  109. #endif
  110. // 5-bit quantization
  111. // 8 blocks of 32 elements each
  112. // weight is represented as x = a * q + b
  113. // Effectively 5.5 bits per weight
  114. #ifdef GGML_QKK_64
  115. typedef struct {
  116. ggml_fp16_t d; // super-block scale
  117. int8_t scales[QK_K/16]; // 8-bit block scales
  118. uint8_t qh[QK_K/8]; // quants, high bit
  119. uint8_t qs[QK_K/2]; // quants, low 4 bits
  120. } block_q5_K;
  121. static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
  122. #else
  123. typedef struct {
  124. ggml_fp16_t d; // super-block scale for quantized scales
  125. ggml_fp16_t dmin; // super-block scale for quantized mins
  126. uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  127. uint8_t qh[QK_K/8]; // quants, high bit
  128. uint8_t qs[QK_K/2]; // quants, low 4 bits
  129. } block_q5_K;
  130. static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
  131. #endif
  132. // 6-bit quantization
  133. // weight is represented as x = a * q
  134. // 16 blocks of 16 elements each
  135. // Effectively 6.5625 bits per weight
  136. typedef struct {
  137. uint8_t ql[QK_K/2]; // quants, lower 4 bits
  138. uint8_t qh[QK_K/4]; // quants, upper 2 bits
  139. int8_t scales[QK_K/16]; // scales, quantized with 8 bits
  140. ggml_fp16_t d; // super-block scale
  141. } block_q6_K;
  142. static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
  143. // This is only used for intermediate quantization and dot products
  144. typedef struct {
  145. float d; // delta
  146. int8_t qs[QK_K]; // quants
  147. int16_t bsums[QK_K/16]; // sum of quants in groups of 16
  148. } block_q8_K;
  149. static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
  150. // (Almost) "true" 2-bit quantization.
  151. // Due to the need to use blocks as per ggml dsign, it ends up using
  152. // 2.0625 bpw because of the 16-bit scale for each block of 256.
  153. typedef struct {
  154. ggml_fp16_t d;
  155. uint16_t qs[QK_K/8];
  156. } block_iq2_xxs;
  157. static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding");
  158. // 2.3125 bpw quants
  159. typedef struct {
  160. ggml_fp16_t d;
  161. uint16_t qs[QK_K/8];
  162. uint8_t scales[QK_K/32];
  163. } block_iq2_xs;
  164. static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
  165. // Quantization
  166. void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k);
  167. void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k);
  168. void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k);
  169. void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k);
  170. void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k);
  171. void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k);
  172. void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
  173. void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
  174. void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
  175. void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
  176. void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
  177. void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
  // Quantizers taking an untyped destination: same parameter order as the
  // *_reference variants, but y is a void pointer instead of a typed block
  // pointer. x and y are restrict-qualified, so they must not alias.
  // NOTE(review): y presumably points to an array of the matching block_*
  // type and k is the element count of the row - confirm against the
  // implementation.
  void quantize_row_q4_0(const float * restrict x, void * restrict y, int k);
  void quantize_row_q4_1(const float * restrict x, void * restrict y, int k);
  void quantize_row_q5_0(const float * restrict x, void * restrict y, int k);
  void quantize_row_q5_1(const float * restrict x, void * restrict y, int k);
  void quantize_row_q8_0(const float * restrict x, void * restrict y, int k);
  void quantize_row_q8_1(const float * restrict x, void * restrict y, int k);

  void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
  void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
  190. // Dequantization
  191. void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k);
  192. void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k);
  193. void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k);
  194. void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k);
  195. void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k);
  196. //void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k);
  197. void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
  198. void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
  199. void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
  200. void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
  201. void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
  202. void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
  203. void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int k);
  204. void dequantize_row_iq2_xs (const block_iq2_xs * restrict x, float * restrict y, int k);
  // Dot product
  //
  // Each function is named ggml_vec_dot_<A>_<B> after the two block formats it
  // combines; the result is returned through s. All pointers are
  // restrict-qualified, so s, vx and vy must not alias.
  // NOTE(review): vx / vy presumably point to rows of block_<A> / block_<B>
  // and n is the number of elements per row - confirm against the
  // implementation.
  void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);

  void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_iq2_xxs_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_iq2_xs_q8_K (int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  //
  // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
  //
  // All entry points share one signature: nrows rows of n_per_row floats are
  // read from src and the quantized output is written to dst.
  // NOTE(review): the size_t result is presumably the number of bytes written
  // to dst, hist a histogram updated during quantization, and imatrix an
  // optional per-column importance vector - confirm against the
  // implementation.
  size_t quantize_iq2_xxs(const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_iq2_xs (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_q2_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_q3_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_q4_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_q5_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_q6_K   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_q4_0   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_q4_1   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_q5_0   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  size_t quantize_q5_1   (const float * src, void * dst, int nrows, int n_per_row, int64_t * hist, const float * imatrix);
  // Paired set-up / tear-down entry points for the iq2 quantizers, selected by
  // grid_size.
  // NOTE(review): presumably these build and release lookup data for the
  // given grid size, and each init should be matched by a free with the same
  // grid_size - confirm against the implementation.
  void iq2xs_init_impl(int grid_size);
  void iq2xs_free_impl(int grid_size);