quantize_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. package quant
  2. import (
  3. "math/rand"
  4. "testing"
  5. "time"
  6. )
  7. func BenchmarkQuantizeQ8K(b *testing.B) {
  8. // 1M floats (typical small tensor)
  9. data := make([]float32, 1024*1024)
  10. r := rand.New(rand.NewSource(42))
  11. for i := range data {
  12. data[i] = r.Float32()*2 - 1 // [-1, 1]
  13. }
  14. b.ResetTimer()
  15. for i := 0; i < b.N; i++ {
  16. _ = QuantizeQ8K(data)
  17. }
  18. }
  19. func BenchmarkQuantizeQ6K(b *testing.B) {
  20. data := make([]float32, 1024*1024)
  21. r := rand.New(rand.NewSource(42))
  22. for i := range data {
  23. data[i] = r.Float32()*2 - 1
  24. }
  25. b.ResetTimer()
  26. for i := 0; i < b.N; i++ {
  27. _ = QuantizeQ6K(data)
  28. }
  29. }
  30. func BenchmarkQuantizeQ4K(b *testing.B) {
  31. data := make([]float32, 1024*1024)
  32. r := rand.New(rand.NewSource(42))
  33. for i := range data {
  34. data[i] = r.Float32()*2 - 1
  35. }
  36. b.ResetTimer()
  37. for i := 0; i < b.N; i++ {
  38. _ = QuantizeQ4K(data)
  39. }
  40. }
  41. func TestQuantizeQ8K_Basic(t *testing.T) {
  42. // Simple test: 256 elements
  43. data := make([]float32, 256)
  44. for i := range data {
  45. data[i] = float32(i-128) / 128.0 // [-1, ~1]
  46. }
  47. start := time.Now()
  48. result := QuantizeQ8K(data)
  49. elapsed := time.Since(start)
  50. // Expect 292 bytes (1 block)
  51. if len(result) != 292 {
  52. t.Errorf("Expected 292 bytes, got %d", len(result))
  53. }
  54. t.Logf("Q8K: 256 floats -> %d bytes in %v", len(result), elapsed)
  55. }
  56. func TestQuantizeQ6K_Basic(t *testing.T) {
  57. data := make([]float32, 256)
  58. for i := range data {
  59. data[i] = float32(i-128) / 128.0
  60. }
  61. start := time.Now()
  62. result := QuantizeQ6K(data)
  63. elapsed := time.Since(start)
  64. // Expect 210 bytes (1 block)
  65. if len(result) != 210 {
  66. t.Errorf("Expected 210 bytes, got %d", len(result))
  67. }
  68. t.Logf("Q6K: 256 floats -> %d bytes in %v", len(result), elapsed)
  69. }
  70. func TestQuantizeQ4K_Basic(t *testing.T) {
  71. data := make([]float32, 256)
  72. for i := range data {
  73. data[i] = float32(i-128) / 128.0
  74. }
  75. start := time.Now()
  76. result := QuantizeQ4K(data)
  77. elapsed := time.Since(start)
  78. // Expect 144 bytes (1 block)
  79. if len(result) != 144 {
  80. t.Errorf("Expected 144 bytes, got %d", len(result))
  81. }
  82. t.Logf("Q4K: 256 floats -> %d bytes in %v", len(result), elapsed)
  83. }
  84. func TestLargeQuantization(t *testing.T) {
  85. // Test with 4M elements (typical large weight matrix)
  86. size := 4 * 1024 * 1024
  87. data := make([]float32, size)
  88. r := rand.New(rand.NewSource(42))
  89. for i := range data {
  90. data[i] = r.Float32()*2 - 1
  91. }
  92. t.Run("Q8K_4M", func(t *testing.T) {
  93. start := time.Now()
  94. result := QuantizeQ8K(data)
  95. elapsed := time.Since(start)
  96. mbps := float64(size*4) / elapsed.Seconds() / (1024 * 1024)
  97. t.Logf("Q8K: %d floats (%.1f MB) -> %d bytes in %v (%.1f MB/s)",
  98. size, float64(size*4)/(1024*1024), len(result), elapsed, mbps)
  99. })
  100. t.Run("Q6K_4M", func(t *testing.T) {
  101. start := time.Now()
  102. result := QuantizeQ6K(data)
  103. elapsed := time.Since(start)
  104. mbps := float64(size*4) / elapsed.Seconds() / (1024 * 1024)
  105. t.Logf("Q6K: %d floats (%.1f MB) -> %d bytes in %v (%.1f MB/s)",
  106. size, float64(size*4)/(1024*1024), len(result), elapsed, mbps)
  107. })
  108. t.Run("Q4K_4M", func(t *testing.T) {
  109. start := time.Now()
  110. result := QuantizeQ4K(data)
  111. elapsed := time.Since(start)
  112. mbps := float64(size*4) / elapsed.Seconds() / (1024 * 1024)
  113. t.Logf("Q4K: %d floats (%.1f MB) -> %d bytes in %v (%.1f MB/s)",
  114. size, float64(size*4)/(1024*1024), len(result), elapsed, mbps)
  115. })
  116. }