// Command quantize converts F32/F16 .mak files to quantized versions.
//
// Usage:
//
//	quantize input.mak output.mak q4_k
//	quantize --mix input.mak output.mak q4_k
//
// The --mix flag enables smart mix quantization (architecture-specific
// per-tensor rules). As is standard for the flag package, flags must
// precede the positional arguments.
package main

import (
    "encoding/binary"
    "flag"
    "fmt"
    "math"
    "os"
    "time"

    "makarna/pkg/convert"
    "makarna/pkg/loader"
    "makarna/pkg/quant"

    _ "makarna/pkg/model/models" // register architecture-specific model plugins
)

func main() {
    mixMode := flag.Bool("mix", false, "Enable smart mix quantization (uses architecture-specific rules)")
    flag.Parse()

    args := flag.Args()
    if len(args) < 3 {
        fmt.Println("Usage: quantize [--mix] <input.mak> <output.mak> <quant_type>")
        fmt.Println("Quant types: q2_k, q3_k, q4_k, q5_k, q6_k, q8_k")
        fmt.Println("Flags:")
        fmt.Println("  --mix    Enable smart mix quantization (uses architecture-specific rules)")
        os.Exit(1)
    }

    inputPath := args[0]
    outputPath := args[1]
    quantTypeStr := args[2]
    baseQuant := quant.QuantType(quantTypeStr)

    // Validate the requested quant type.
    switch baseQuant {
    case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K, quant.TypeQ5K, quant.TypeQ6K, quant.TypeQ8K:
        // OK
    default:
        fmt.Printf("Unknown quant type: %s\n", quantTypeStr)
        os.Exit(1)
    }

    fmt.Printf("Loading %s...\n", inputPath)
    startTime := time.Now()

    model, err := loader.Load(inputPath)
    if err != nil {
        fmt.Printf("Error loading model: %v\n", err)
        os.Exit(1)
    }
    defer model.Close()

    architecture := model.Metadata.ModelConfig.Architecture
    fmt.Printf("Loaded in %v\n", time.Since(startTime))
    fmt.Printf("Architecture: %s\n", architecture)

    // Build a policy spec (a model plugin may override mix behavior).
    tieWordEmbeddings := false
    if tie, ok := model.Metadata.ModelConfig.Params["tie_word_embeddings"].(bool); ok {
        tieWordEmbeddings = tie
    }
    spec := convert.NewSpec(architecture, tieWordEmbeddings, baseQuant, *mixMode)

    // Create the output writer.
    writer, err := loader.NewWriter(outputPath)
    if err != nil {
        fmt.Printf("Error creating output: %v\n", err)
        os.Exit(1)
    }
    writer.SetModelConfig(model.Metadata.ModelConfig)

    // Copy the embedded tokenizer if present.
    tokData, err := model.GetTokenizerData()
    if err == nil && len(tokData) > 0 {
        fmt.Println("Copying embedded tokenizer...")
        writer.AddTokenizer(tokData)
    }

    // Process tensors.
    totalTensors := len(model.Metadata.Tensors)
    stats := make(map[quant.QuantType]int)
    skipped := 0

    fmt.Printf("\nQuantizing %d tensors...\n\n", totalTensors)

    for name, info := range model.Metadata.Tensors {
        data, err := model.GetTensorData(name)
        if err != nil {
            fmt.Printf("Error reading tensor %s: %v\n", name, err)
            continue
        }

        // Only F32 tensors with at least two dimensions are quantizable,
        // and the last dimension must be a multiple of 256 (the K-quant
        // super-block size).
        nDims := len(info.Shape)
        isQuantizable := nDims >= 2 && info.DType == loader.F32
        if isQuantizable && info.Shape[nDims-1]%256 != 0 {
            isQuantizable = false
        }

        var outData []byte
        var outDType loader.DType

        if isQuantizable {
            // Resolve the per-tensor quant type (mix mode may override the base).
            tensorQuant := spec.ResolveQuant(name, baseQuant)

            if tensorQuant == quant.TypeF32 || tensorQuant == quant.TypeF16 {
                // Preserve the tensor. The source data is F32 and no
                // F32→F16 conversion happens here, so keep the original
                // dtype rather than relabeling the bytes.
                outData = data
                outDType = info.DType
                stats[tensorQuant]++
                fmt.Printf("  %s: %v [%s] (preserved)\n", name, info.Shape, tensorQuant)
            } else {
                // Reinterpret the raw bytes as float32 and quantize.
                floats := bytesToFloat32(data)
                start := time.Now()
                switch tensorQuant {
                case quant.TypeQ2K:
                    outData = quant.QuantizeQ2K(floats)
                case quant.TypeQ3K:
                    outData = quant.QuantizeQ3K(floats)
                case quant.TypeQ4K:
                    outData = quant.QuantizeQ4K(floats)
                case quant.TypeQ5K:
                    outData = quant.QuantizeQ5K(floats)
                case quant.TypeQ6K:
                    outData = quant.QuantizeQ6K(floats)
                case quant.TypeQ8K:
                    outData = quant.QuantizeQ8K(floats)
                default:
                    // Unknown resolved type: fall back to Q4_K.
                    outData = quant.QuantizeQ4K(floats)
                    tensorQuant = quant.TypeQ4K
                }
                elapsed := time.Since(start)

                outDType = tensorQuant.ToDType()
                stats[tensorQuant]++
                ratio := float64(len(data)) / float64(len(outData))

                // Note when mix mode changed the type from the base.
                mixInfo := ""
                if *mixMode && tensorQuant != baseQuant {
                    mixInfo = fmt.Sprintf(" (mix: %s→%s)", baseQuant, tensorQuant)
                }
                fmt.Printf("  %s: %v → %s (%.2fx, %v)%s\n", name, info.Shape, tensorQuant, ratio, elapsed, mixInfo)
            }
        } else {
            // Not quantizable: copy through unchanged.
            outData = data
            outDType = info.DType
            skipped++
        }

        // Convert the shape to uint64 for the writer.
        shape := make([]uint64, len(info.Shape))
        for i, s := range info.Shape {
            shape[i] = uint64(s)
        }

        if err := writer.AddTensor(name, outDType, shape, outData); err != nil {
            fmt.Printf("Error writing tensor %s: %v\n", name, err)
        }
    }

    if err := writer.Close(); err != nil {
        fmt.Printf("Error closing output: %v\n", err)
        os.Exit(1)
    }

    // Summarize sizes and the per-type breakdown.
    inStat, _ := os.Stat(inputPath)
    outStat, _ := os.Stat(outputPath)

    fmt.Printf("\n✓ Done!\n")
    fmt.Printf("  Quantization breakdown:\n")
    for qt, count := range stats {
        fmt.Printf("    %s: %d tensors\n", qt, count)
    }
    fmt.Printf("  Skipped: %d tensors\n", skipped)
    fmt.Printf("  Input size:  %.2f MB\n", float64(inStat.Size())/(1024*1024))
    fmt.Printf("  Output size: %.2f MB\n", float64(outStat.Size())/(1024*1024))
    fmt.Printf("  Compression: %.2fx\n", float64(inStat.Size())/float64(outStat.Size()))
    fmt.Printf("  Total time:  %v\n", time.Since(startTime))
}

// bytesToFloat32 reinterprets a little-endian byte slice as []float32.
func bytesToFloat32(data []byte) []float32 {
    n := len(data) / 4
    result := make([]float32, n)
    for i := 0; i < n; i++ {
        bits := binary.LittleEndian.Uint32(data[i*4 : i*4+4])
        result[i] = math.Float32frombits(bits)
    }
    return result
}