package tests

import (
	"math"
	"math/rand"
	"testing"
	"unsafe"

	"makarna/pkg/quant"
	"makarna/pkg/tensor"
)

// TestQ2K_RoundTrip quantizes each shared test pattern with the Go Q2_K
// implementation, dequantizes the result, and checks both the mean squared
// error and the fraction of values whose sign survived the round trip.
func TestQ2K_RoundTrip(t *testing.T) {
	patterns := qkTestPatterns()
	for i, expected := range patterns {
		q := quant.QuantizeQ2K(expected)
		// The test expects a single 84-byte Q2_K block per 256-value pattern.
		if len(q) != 84 {
			t.Fatalf("unexpected q2k buffer size %d", len(q))
		}
		block := (*tensor.BlockQ2_K)(unsafe.Pointer(&q[0]))
		var out [256]float32
		tensor.DequantizeQ2_K(block, out[:])

		mse := 0.0
		maxDiff := 0.0
		signMatches := 0
		for j := 0; j < 256; j++ {
			diff := math.Abs(float64(out[j] - expected[j]))
			mse += diff * diff
			if diff > maxDiff {
				maxDiff = diff
			}
			if (out[j] >= 0) == (expected[j] >= 0) {
				signMatches++
			}
		}
		mse /= 256.0
		signMatchRate := float64(signMatches) / 256.0

		// Per-pattern thresholds; indices follow the order of qkTestPatterns.
		desc := "Unknown"
		tolerance := 0.5
		minSignMatch := 0.85
		switch i {
		case 0:
			desc = "Gradient"
			tolerance = 0.1
		case 1:
			desc = "Random Normal"
			tolerance = 0.2
		case 2:
			desc = "Random Large"
			tolerance = 250.0
			minSignMatch = 0.70
		case 3:
			desc = "Sparse"
			tolerance = 0.1
			minSignMatch = 0.80
		}

		t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f, SignMatch=%.1f%%", i, desc, mse, maxDiff, signMatchRate*100)
		// A NaN anywhere in the output makes mse NaN, and every ordered
		// comparison with NaN is false, so test for it explicitly instead of
		// letting a NaN-producing dequantizer pass.
		if math.IsNaN(mse) || mse > tolerance {
			t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance)
		}
		if signMatchRate < minSignMatch {
			t.Errorf("Block %d (%s) sign match too low: %.1f%% (min=%.1f%%)", i, desc, signMatchRate*100, minSignMatch*100)
		}
	}
}

// TestQ3K_GoldenIter quantizes the shared test patterns with the Go Q3_K
// implementation (instead of comparing against stored Python goldens, which
// can go stale) and checks the dequantized MSE against qkPatternInfo.
func TestQ3K_GoldenIter(t *testing.T) { patterns := qkTestPatterns() for i, expected := range patterns { q := quant.QuantizeQ3K(expected) if len(q)%110 != 0 || len(q) == 0 { t.Fatalf("unexpected q3k buffer size %d", len(q)) } block := (*tensor.BlockQ3_K)(unsafe.Pointer(&q[0])) var out [256]float32 tensor.DequantizeQ3_K(block, out[:]) mse := 0.0 maxDiff := 0.0 for j := 0; j < 256; j++ { diff := math.Abs(float64(out[j] - expected[j])) mse += diff * diff if diff > maxDiff { maxDiff = diff } } mse /= 256.0 desc, tolerance := qkPatternInfo(i) t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f", i, desc, mse, maxDiff) if mse > tolerance { t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance) } } } func qkTestPatterns() [][]float32 { patterns := make([][]float32, 0, 4) // Pattern 1: Gradient [-0.5, 0.5] grad := make([]float32, 256) for i := 0; i < 256; i++ { grad[i] = -0.5 + float32(i)*(1.0/255.0) } patterns = append(patterns, grad) // Pattern 2: Random Normal (seed 42, std 0.05) rngNormal := rand.New(rand.NewSource(42)) normal := make([]float32, 256) for i := range normal { normal[i] = float32(rngNormal.NormFloat64() * 0.05) } patterns = append(patterns, normal) // Pattern 3: Random Large (seed 123, std 50) rngLarge := rand.New(rand.NewSource(123)) large := make([]float32, 256) for i := range large { large[i] = float32(rngLarge.NormFloat64() * 50.0) } patterns = append(patterns, large) // Pattern 4: Sparse (next values from same rngLarge) sparse := make([]float32, 256) for i := 0; i < 16; i++ { sparse[i*16] = float32(rngLarge.NormFloat64() * 0.1) } patterns = append(patterns, sparse) return patterns } func qkPatternInfo(i int) (desc string, tolerance float64) { desc = "Unknown" tolerance = 0.15 switch i { case 0: desc = "Gradient" tolerance = 0.05 case 1: desc = "Random Normal" tolerance = 0.1 case 2: desc = "Random Large" tolerance = 3000.0 case 3: desc = "Sparse" tolerance = 0.05 } return } // TestQ6K_RoundTrip quantizes patterns with Go impl and checks 
dequantization. func TestQ6K_RoundTrip(t *testing.T) { patterns := qkTestPatterns() for i, expected := range patterns { q := quant.QuantizeQ6K(expected) if len(q) != 210 { t.Fatalf("unexpected q6k buffer size %d", len(q)) } block := (*tensor.BlockQ6_K)(unsafe.Pointer(&q[0])) var out [256]float32 tensor.DequantizeQ6_K(block, out[:]) mse := 0.0 maxDiff := 0.0 for j := 0; j < 256; j++ { diff := math.Abs(float64(out[j] - expected[j])) mse += diff * diff if diff > maxDiff { maxDiff = diff } } mse /= 256.0 desc := "Unknown" tolerance := 0.02 switch i { case 0: desc = "Gradient" tolerance = 0.01 case 1: desc = "Random Normal" tolerance = 0.015 case 2: desc = "Random Large" tolerance = 2.0 case 3: desc = "Sparse" tolerance = 0.01 } t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f", i, desc, mse, maxDiff) if mse > tolerance { t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance) } } } // TestQ4K_RoundTrip quantizes patterns with Go impl and checks dequantization. func TestQ4K_RoundTrip(t *testing.T) { patterns := qkTestPatterns() for i, expected := range patterns { q := quant.QuantizeQ4K(expected) if len(q) != 144 { t.Fatalf("unexpected q4k buffer size %d", len(q)) } block := (*tensor.BlockQ4_K)(unsafe.Pointer(&q[0])) var out [256]float32 tensor.DequantizeQ4_K(block, out[:]) mse := 0.0 maxDiff := 0.0 for j := 0; j < 256; j++ { diff := math.Abs(float64(out[j] - expected[j])) mse += diff * diff if diff > maxDiff { maxDiff = diff } } mse /= 256.0 desc := "Unknown" tolerance := 0.05 switch i { case 0: desc = "Gradient" case 1: desc = "Random Normal" case 2: desc = "Random Large" tolerance = 20.0 case 3: desc = "Sparse" } t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f", i, desc, mse, maxDiff) if mse > tolerance { t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance) } } } // TestQ8K_RoundTrip quantizes patterns with Go impl and checks dequantization. 
func TestQ8K_RoundTrip(t *testing.T) { patterns := qkTestPatterns() for i, expected := range patterns { q := quant.QuantizeQ8K(expected) if len(q) != 292 { t.Fatalf("unexpected q8k buffer size %d", len(q)) } block := (*tensor.BlockQ8_K)(unsafe.Pointer(&q[0])) var out [256]float32 tensor.DequantizeQ8_K(block, out[:]) mse := 0.0 maxDiff := 0.0 for j := 0; j < 256; j++ { diff := math.Abs(float64(out[j] - expected[j])) mse += diff * diff if diff > maxDiff { maxDiff = diff } } mse /= 256.0 desc := "Unknown" tolerance := 0.01 switch i { case 0: desc = "Gradient" case 1: desc = "Random Normal" case 2: desc = "Random Large" tolerance = 1.0 case 3: desc = "Sparse" } t.Logf("Block %d (%s): MSE=%f, MaxDiff=%f", i, desc, mse, maxDiff) if mse > tolerance { t.Errorf("Block %d (%s) MSE too high: %f (tol=%f)", i, desc, mse, tolerance) } } }