1
0

quant_test.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. package tests
  2. import (
  3. "math"
  4. "math/rand"
  5. "testing"
  6. "unsafe"
  7. "makarna/pkg/quant"
  8. "makarna/pkg/tensor"
  9. )
  10. // TestQ2K_RoundTrip quantizes patterns with Go impl and checks dequantization.
  11. func TestQ2K_RoundTrip(t *testing.T) {
  12. patterns := qkTestPatterns()
  13. for i, expected := range patterns {
  14. q := quant.QuantizeQ2K(expected)
  15. if len(q) != 84 {
  16. t.Fatalf("unexpected q2k buffer size %d", len(q))
  17. }
  18. block := (*tensor.BlockQ2_K)(unsafe.Pointer(&q[0]))
  19. var out [256]float32
  20. tensor.DequantizeQ2_K(block, out[:])
  21. mse := 0.0
  22. maxDiff := 0.0
  23. signMatches := 0
  24. for j := 0; j < 256; j++ {
  25. diff := math.Abs(float64(out[j] - expected[j]))
  26. mse += diff * diff
  27. if diff > maxDiff {
  28. maxDiff = diff
  29. }
  30. if (out[j] >= 0) == (expected[j] >= 0) {
  31. signMatches++
  32. }
  33. }
  34. mse /= 256.0
  35. signMatchRate := float64(signMatches) / 256.0
  36. desc := "Unknown"
  37. tolerance := 0.5
  38. minSignMatch := 0.85
  39. switch i {
  40. case 0:
  41. desc = "Gradient"
  42. tolerance = 0.1
  43. case 1:
  44. desc = "Random Normal"
  45. tolerance = 0.2
  46. case 2:
  47. desc = "Random Large"
  48. tolerance = 250.0
  49. minSignMatch = 0.70
  50. case 3:
  51. desc = "Sparse"
  52. tolerance = 0.1
  53. minSignMatch = 0.80
  54. }
  55. t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f, SignMatch=%.1f%%", i, desc, mse, maxDiff, signMatchRate*100)
  56. if mse > tolerance {
  57. t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance)
  58. }
  59. if signMatchRate < minSignMatch {
  60. t.Errorf("Block %d (%s) sign match too low: %.1f%% (min=%.1f%%)", i, desc, signMatchRate*100, minSignMatch*100)
  61. }
  62. }
  63. }
  64. // TestQ3K_GoldenIter generates patterns and quantizes with Go impl to avoid stale python goldens.
  65. func TestQ3K_GoldenIter(t *testing.T) {
  66. patterns := qkTestPatterns()
  67. for i, expected := range patterns {
  68. q := quant.QuantizeQ3K(expected)
  69. if len(q)%110 != 0 || len(q) == 0 {
  70. t.Fatalf("unexpected q3k buffer size %d", len(q))
  71. }
  72. block := (*tensor.BlockQ3_K)(unsafe.Pointer(&q[0]))
  73. var out [256]float32
  74. tensor.DequantizeQ3_K(block, out[:])
  75. mse := 0.0
  76. maxDiff := 0.0
  77. for j := 0; j < 256; j++ {
  78. diff := math.Abs(float64(out[j] - expected[j]))
  79. mse += diff * diff
  80. if diff > maxDiff {
  81. maxDiff = diff
  82. }
  83. }
  84. mse /= 256.0
  85. desc, tolerance := qkPatternInfo(i)
  86. t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f", i, desc, mse, maxDiff)
  87. if mse > tolerance {
  88. t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance)
  89. }
  90. }
  91. }
  92. func qkTestPatterns() [][]float32 {
  93. patterns := make([][]float32, 0, 4)
  94. // Pattern 1: Gradient [-0.5, 0.5]
  95. grad := make([]float32, 256)
  96. for i := 0; i < 256; i++ {
  97. grad[i] = -0.5 + float32(i)*(1.0/255.0)
  98. }
  99. patterns = append(patterns, grad)
  100. // Pattern 2: Random Normal (seed 42, std 0.05)
  101. rngNormal := rand.New(rand.NewSource(42))
  102. normal := make([]float32, 256)
  103. for i := range normal {
  104. normal[i] = float32(rngNormal.NormFloat64() * 0.05)
  105. }
  106. patterns = append(patterns, normal)
  107. // Pattern 3: Random Large (seed 123, std 50)
  108. rngLarge := rand.New(rand.NewSource(123))
  109. large := make([]float32, 256)
  110. for i := range large {
  111. large[i] = float32(rngLarge.NormFloat64() * 50.0)
  112. }
  113. patterns = append(patterns, large)
  114. // Pattern 4: Sparse (next values from same rngLarge)
  115. sparse := make([]float32, 256)
  116. for i := 0; i < 16; i++ {
  117. sparse[i*16] = float32(rngLarge.NormFloat64() * 0.1)
  118. }
  119. patterns = append(patterns, sparse)
  120. return patterns
  121. }
  122. func qkPatternInfo(i int) (desc string, tolerance float64) {
  123. desc = "Unknown"
  124. tolerance = 0.15
  125. switch i {
  126. case 0:
  127. desc = "Gradient"
  128. tolerance = 0.05
  129. case 1:
  130. desc = "Random Normal"
  131. tolerance = 0.1
  132. case 2:
  133. desc = "Random Large"
  134. tolerance = 3000.0
  135. case 3:
  136. desc = "Sparse"
  137. tolerance = 0.05
  138. }
  139. return
  140. }
  141. // TestQ6K_RoundTrip quantizes patterns with Go impl and checks dequantization.
  142. func TestQ6K_RoundTrip(t *testing.T) {
  143. patterns := qkTestPatterns()
  144. for i, expected := range patterns {
  145. q := quant.QuantizeQ6K(expected)
  146. if len(q) != 210 {
  147. t.Fatalf("unexpected q6k buffer size %d", len(q))
  148. }
  149. block := (*tensor.BlockQ6_K)(unsafe.Pointer(&q[0]))
  150. var out [256]float32
  151. tensor.DequantizeQ6_K(block, out[:])
  152. mse := 0.0
  153. maxDiff := 0.0
  154. for j := 0; j < 256; j++ {
  155. diff := math.Abs(float64(out[j] - expected[j]))
  156. mse += diff * diff
  157. if diff > maxDiff {
  158. maxDiff = diff
  159. }
  160. }
  161. mse /= 256.0
  162. desc := "Unknown"
  163. tolerance := 0.02
  164. switch i {
  165. case 0:
  166. desc = "Gradient"
  167. tolerance = 0.01
  168. case 1:
  169. desc = "Random Normal"
  170. tolerance = 0.015
  171. case 2:
  172. desc = "Random Large"
  173. tolerance = 2.0
  174. case 3:
  175. desc = "Sparse"
  176. tolerance = 0.01
  177. }
  178. t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f", i, desc, mse, maxDiff)
  179. if mse > tolerance {
  180. t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance)
  181. }
  182. }
  183. }
  184. // TestQ4K_RoundTrip quantizes patterns with Go impl and checks dequantization.
  185. func TestQ4K_RoundTrip(t *testing.T) {
  186. patterns := qkTestPatterns()
  187. for i, expected := range patterns {
  188. q := quant.QuantizeQ4K(expected)
  189. if len(q) != 144 {
  190. t.Fatalf("unexpected q4k buffer size %d", len(q))
  191. }
  192. block := (*tensor.BlockQ4_K)(unsafe.Pointer(&q[0]))
  193. var out [256]float32
  194. tensor.DequantizeQ4_K(block, out[:])
  195. mse := 0.0
  196. maxDiff := 0.0
  197. for j := 0; j < 256; j++ {
  198. diff := math.Abs(float64(out[j] - expected[j]))
  199. mse += diff * diff
  200. if diff > maxDiff {
  201. maxDiff = diff
  202. }
  203. }
  204. mse /= 256.0
  205. desc := "Unknown"
  206. tolerance := 0.05
  207. switch i {
  208. case 0:
  209. desc = "Gradient"
  210. case 1:
  211. desc = "Random Normal"
  212. case 2:
  213. desc = "Random Large"
  214. tolerance = 20.0
  215. case 3:
  216. desc = "Sparse"
  217. }
  218. t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f", i, desc, mse, maxDiff)
  219. if mse > tolerance {
  220. t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance)
  221. }
  222. }
  223. }
  224. // TestQ8K_RoundTrip quantizes patterns with Go impl and checks dequantization.
  225. func TestQ8K_RoundTrip(t *testing.T) {
  226. patterns := qkTestPatterns()
  227. for i, expected := range patterns {
  228. q := quant.QuantizeQ8K(expected)
  229. if len(q) != 292 {
  230. t.Fatalf("unexpected q8k buffer size %d", len(q))
  231. }
  232. block := (*tensor.BlockQ8_K)(unsafe.Pointer(&q[0]))
  233. var out [256]float32
  234. tensor.DequantizeQ8_K(block, out[:])
  235. mse := 0.0
  236. maxDiff := 0.0
  237. for j := 0; j < 256; j++ {
  238. diff := math.Abs(float64(out[j] - expected[j]))
  239. mse += diff * diff
  240. if diff > maxDiff {
  241. maxDiff = diff
  242. }
  243. }
  244. mse /= 256.0
  245. desc := "Unknown"
  246. tolerance := 0.01
  247. switch i {
  248. case 0:
  249. desc = "Gradient"
  250. case 1:
  251. desc = "Random Normal"
  252. case 2:
  253. desc = "Random Large"
  254. tolerance = 1.0
  255. case 3:
  256. desc = "Sparse"
  257. }
  258. t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f", i, desc, mse, maxDiff)
  259. if mse > tolerance {
  260. t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance)
  261. }
  262. }
  263. }