// k_quants.h
  #pragma once

  #include "ggml.h"

  #include <stdint.h>
  #include <assert.h>
  #include <stddef.h>
  // Super-block size
  //
  // QK_K is the number of weights grouped into one super-block. Defining
  // GGML_QKK_64 at build time selects 64-element super-blocks; otherwise the
  // default of 256 is used.
  // K_SCALE_SIZE is the byte length of the packed scales array used by the
  // default (non-GGML_QKK_64) block_q4_K and block_q5_K layouts.
  #ifdef GGML_QKK_64
  #define QK_K 64
  #define K_SCALE_SIZE 4
  #else
  #define QK_K 256
  #define K_SCALE_SIZE 12
  #endif
  // Make static_assert(cond, msg) available when no macro of that name has been
  // defined yet (in C11, <assert.h> already provides one).
  #ifndef static_assert
  #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
  // C11 or newer: forward to the _Static_assert keyword.
  #define static_assert(cond, msg) _Static_assert(cond, msg)
  #else
  // Older dialects: expand to a harmless struct declaration so that the
  // statement (including its trailing ';') still parses at file scope.
  // NOTE: in this fallback the condition is NOT evaluated or checked.
  #define static_assert(cond, msg) struct global_scope_noop_trick
  #endif
  #endif
  21. //
  22. // Super-block quantization structures
  23. //
  24. // 2-bit quantization
  25. // weight is represented as x = a * q + b
  26. // 16 blocks of 16 elemenets each
  27. // Effectively 2.5625 bits per weight
  28. typedef struct {
  29. uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
  30. uint8_t qs[QK_K/4]; // quants
  31. ggml_fp16_t d; // super-block scale for quantized scales
  32. ggml_fp16_t dmin; // super-block scale for quantized mins
  33. } block_q2_K;
  34. static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
  35. // 3-bit quantization
  36. // weight is represented as x = a * q
  37. // 16 blocks of 16 elemenets each
  38. // Effectively 3.4375 bits per weight
  39. #ifdef GGML_QKK_64
  40. typedef struct {
  41. uint8_t hmask[QK_K/8]; // quants - high bit
  42. uint8_t qs[QK_K/4]; // quants - low 2 bits
  43. uint8_t scales[2];
  44. ggml_fp16_t d; // super-block scale
  45. } block_q3_K;
  46. static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding");
  47. #else
  48. typedef struct {
  49. uint8_t hmask[QK_K/8]; // quants - high bit
  50. uint8_t qs[QK_K/4]; // quants - low 2 bits
  51. uint8_t scales[12]; // scales, quantized with 6 bits
  52. ggml_fp16_t d; // super-block scale
  53. } block_q3_K;
  54. static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding");
  55. #endif
  56. // 4-bit quantization
  57. // 16 blocks of 32 elements each
  58. // weight is represented as x = a * q + b
  59. // Effectively 4.5 bits per weight
  60. #ifdef GGML_QKK_64
  61. typedef struct {
  62. ggml_fp16_t d[2]; // super-block scales/mins
  63. uint8_t scales[2]; // 4-bit block scales/mins
  64. uint8_t qs[QK_K/2]; // 4--bit quants
  65. } block_q4_K;
  66. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
  67. #else
  68. typedef struct {
  69. ggml_fp16_t d; // super-block scale for quantized scales
  70. ggml_fp16_t dmin; // super-block scale for quantized mins
  71. uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  72. uint8_t qs[QK_K/2]; // 4--bit quants
  73. } block_q4_K;
  74. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding");
  75. #endif
  76. // 5-bit quantization
  77. // 16 blocks of 32 elements each
  78. // weight is represented as x = a * q + b
  79. // Effectively 5.5 bits per weight
  80. #ifdef GGML_QKK_64
  81. typedef struct {
  82. ggml_fp16_t d; // super-block scale
  83. int8_t scales[QK_K/16]; // 8-bit block scales
  84. uint8_t qh[QK_K/8]; // quants, high bit
  85. uint8_t qs[QK_K/2]; // quants, low 4 bits
  86. } block_q5_K;
  87. static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
  88. #else
  89. typedef struct {
  90. ggml_fp16_t d; // super-block scale for quantized scales
  91. ggml_fp16_t dmin; // super-block scale for quantized mins
  92. uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
  93. uint8_t qh[QK_K/8]; // quants, high bit
  94. uint8_t qs[QK_K/2]; // quants, low 4 bits
  95. } block_q5_K;
  96. static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
  97. #endif
  98. // 6-bit quantization
  99. // weight is represented as x = a * q
  100. // 16 blocks of 16 elemenets each
  101. // Effectively 6.5625 bits per weight
  102. typedef struct {
  103. uint8_t ql[QK_K/2]; // quants, lower 4 bits
  104. uint8_t qh[QK_K/4]; // quants, upper 2 bits
  105. int8_t scales[QK_K/16]; // scales, quantized with 8 bits
  106. ggml_fp16_t d; // super-block scale
  107. } block_q6_K;
  108. static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
  109. // This is only used for intermediate quantization and dot products
  110. typedef struct {
  111. float d; // delta
  112. int8_t qs[QK_K]; // quants
  113. int16_t bsums[QK_K/16]; // sum of quants in groups of 16
  114. } block_q8_K;
  115. static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
  116. // Quantization
  117. void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
  118. void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
  119. void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
  120. void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
  121. void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
  122. void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
  123. void quantize_row_q2_K(const float * restrict x, void * restrict y, int k);
  124. void quantize_row_q3_K(const float * restrict x, void * restrict y, int k);
  125. void quantize_row_q4_K(const float * restrict x, void * restrict y, int k);
  126. void quantize_row_q5_K(const float * restrict x, void * restrict y, int k);
  127. void quantize_row_q6_K(const float * restrict x, void * restrict y, int k);
  128. void quantize_row_q8_K(const float * restrict x, void * restrict y, int k);
  129. // Dequantization
  130. void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
  131. void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
  132. void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
  133. void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
  134. void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
  135. void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
  // Dot product
  //
  //   n  : element count
  //   s  : output float
  //   vx : first operand (read-only, untyped)
  //   vy : second operand (read-only, untyped)
  // NOTE(review): going by the function names, vx presumably points to
  // block_qN_K data and vy to block_q8_K data, with the result stored in *s -
  // confirm against the definitions.
  void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  void ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
  // Quantization with histogram collection
  //
  //   src  : source floats (read-only)
  //   dst  : destination buffer for the quantized data
  //   n, k : element counts
  //   hist : histogram accumulator
  // NOTE(review): n is presumably the total number of floats, k the row length,
  // and the return value the number of bytes written to dst - confirm against
  // the definitions, which are not part of this header.
  size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist);
  size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist);
  size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist);
  size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist);
  size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist);