k_quants.h 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. #pragma once
  2. #include "ggml.h"
  3. #include <stdint.h>
  4. #include <assert.h>
  5. #include <stddef.h>
  6. // Super-block size: number of weights packed into a single block_q*_K super-block
  7. #define QK_K 256
  8. //
  9. // Super-block quantization structures
  10. //
  11. // 2-bit quantization
  12. // weight is represented as x = a * q + b
  13. // 16 blocks of 16 elements each
  14. // Effectively 2.625 bits per weight
  15. typedef struct {
  16. uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
  17. uint8_t qs[QK_K/4]; // quants
  18. ggml_fp16_t d; // super-block scale for quantized scales
  19. ggml_fp16_t dmin; // super-block scale for quantized mins
  20. } block_q2_K;
  21. static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
  22. // 3-bit quantization
  23. // weight is represented as x = a * q
  24. // 16 blocks of 16 elements each
  25. // Effectively 3.4375 bits per weight
  26. typedef struct {
  27. uint8_t hmask[QK_K/8]; // quants - high bit
  28. uint8_t qs[QK_K/4]; // quants - low 2 bits
  29. uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
  30. ggml_fp16_t d; // super-block scale
  31. } block_q3_K;
  32. static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
  33. // 4-bit quantization
  34. // 8 blocks of 32 elements each
  35. // weight is represented as x = a * q + b
  36. // Effectively 4.5 bits per weight
  37. typedef struct {
  38. ggml_fp16_t d; // super-block scale for quantized scales
  39. ggml_fp16_t dmin; // super-block scale for quantized mins
  40. uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
  41. uint8_t qs[QK_K/2]; // 4--bit quants
  42. } block_q4_K;
  43. static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
  44. // 5-bit quantization
  45. // 8 blocks of 32 elements each
  46. // weight is represented as x = a * q + b
  47. // Effectively 5.5 bits per weight
  48. typedef struct {
  49. ggml_fp16_t d; // super-block scale for quantized scales
  50. ggml_fp16_t dmin; // super-block scale for quantized mins
  51. uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
  52. uint8_t qh[QK_K/8]; // quants, high bit
  53. uint8_t qs[QK_K/2]; // quants, low 4 bits
  54. } block_q5_K;
  55. static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
  56. // 6-bit quantization
  57. // weight is represented as x = a * q
  58. // 16 blocks of 16 elements each
  59. // Effectively 6.5625 bits per weight
  60. typedef struct {
  61. uint8_t ql[QK_K/2]; // quants, lower 4 bits
  62. uint8_t qh[QK_K/4]; // quants, upper 2 bits
  63. int8_t scales[QK_K/16]; // scales, quantized with 8 bits
  64. ggml_fp16_t d; // super-block scale
  65. } block_q6_K;
  66. static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding");
  67. // This is only used for intermediate quantization and dot products
  68. typedef struct {
  69. float d; // delta
  70. int8_t qs[QK_K]; // quants
  71. int16_t bsums[QK_K/16]; // sum of quants in groups of 16
  72. } block_q8_K;
  73. static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding");
  74. // Quantization
  75. void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k);
  76. void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k);
  77. void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k);
  78. void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k);
  79. void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k);
  80. void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k);
  // Quantizers taking an untyped destination: x is the float input and y is
  // an opaque output buffer; the two must not alias (restrict).
  // NOTE(review): y presumably receives block_qN_K data and k is presumably the
  // number of input floats - confirm against the implementation.
  void quantize_row_q2_K(const float *restrict x, void *restrict y, int k);
  void quantize_row_q3_K(const float *restrict x, void *restrict y, int k);
  void quantize_row_q4_K(const float *restrict x, void *restrict y, int k);
  void quantize_row_q5_K(const float *restrict x, void *restrict y, int k);
  void quantize_row_q6_K(const float *restrict x, void *restrict y, int k);
  void quantize_row_q8_K(const float *restrict x, void *restrict y, int k);
  87. // Dequantization
  88. void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k);
  89. void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k);
  90. void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k);
  91. void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k);
  92. void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
  93. void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
  94. // Dot product
  // Dot-product kernels pairing each K-quant format with q8_K. The result is
  // written through s; vx and vy are opaque read-only inputs; none may alias (restrict).
  // NOTE(review): vx presumably points at block_qN_K data, vy at block_q8_K data,
  // and n is presumably the element count - confirm against the implementation.
  void ggml_vec_dot_q2_K_q8_K(int n, float *restrict s, const void *restrict vx, const void *restrict vy);
  void ggml_vec_dot_q3_K_q8_K(int n, float *restrict s, const void *restrict vx, const void *restrict vy);
  void ggml_vec_dot_q4_K_q8_K(int n, float *restrict s, const void *restrict vx, const void *restrict vy);
  void ggml_vec_dot_q5_K_q8_K(int n, float *restrict s, const void *restrict vx, const void *restrict vy);
  void ggml_vec_dot_q6_K_q8_K(int n, float *restrict s, const void *restrict vx, const void *restrict vy);
  100. // Quantization with histogram collection
  // Quantizers that also collect a histogram: src is the float input, dst an
  // opaque output buffer, hist an int64_t histogram buffer.
  // NOTE(review): the meaning of n and k, the required length of hist, and what
  // the size_t return value counts are not visible here - confirm against the implementation.
  size_t ggml_quantize_q2_K(const float *src, void *dst, int n, int k, int64_t *hist);
  size_t ggml_quantize_q3_K(const float *src, void *dst, int n, int k, int64_t *hist);
  size_t ggml_quantize_q4_K(const float *src, void *dst, int n, int k, int64_t *hist);
  size_t ggml_quantize_q5_K(const float *src, void *dst, int n, int k, int64_t *hist);
  size_t ggml_quantize_q6_K(const float *src, void *dst, int n, int k, int64_t *hist);