| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212 |
- // Command quantize converts F32/F16 .mak files to quantized versions
- //
- // Usage:
- //
- // quantize input.mak output.mak q4_k
- // quantize input.mak output.mak q4_k --mix (enables smart mix quantization)
- package main
- import (
- "encoding/binary"
- "flag"
- "fmt"
- "math"
- "os"
- "time"
- "makarna/pkg/convert"
- "makarna/pkg/loader"
- "makarna/pkg/quant"
- _ "makarna/pkg/model/models"
- )
- func main() {
- // Parse flags
- mixMode := flag.Bool("mix", false, "Enable smart mix quantization (uses architecture-specific rules)")
- flag.Parse()
- args := flag.Args()
- if len(args) < 3 {
- fmt.Println("Usage: quantize <input.mak> <output.mak> <quant_type> [--mix]")
- fmt.Println("Quant types: q2_k, q3_k, q4_k, q5_k, q6_k, q8_k")
- fmt.Println("Flags:")
- fmt.Println(" --mix Enable smart mix quantization (uses architecture-specific rules)")
- os.Exit(1)
- }
- inputPath := args[0]
- outputPath := args[1]
- quantTypeStr := args[2]
- baseQuant := quant.QuantType(quantTypeStr)
-
- // Validate quant type
- switch baseQuant {
- case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K, quant.TypeQ5K, quant.TypeQ6K, quant.TypeQ8K:
- // OK
- default:
- fmt.Printf("Unknown quant type: %s\n", quantTypeStr)
- os.Exit(1)
- }
- fmt.Printf("Loading %s...\n", inputPath)
- startTime := time.Now()
- model, err := loader.Load(inputPath)
- if err != nil {
- fmt.Printf("Error loading model: %v\n", err)
- os.Exit(1)
- }
- defer model.Close()
- architecture := model.Metadata.ModelConfig.Architecture
- fmt.Printf("Loaded in %v\n", time.Since(startTime))
- fmt.Printf("Architecture: %s\n", architecture)
-
- // Build a policy spec (model plugin may override mix behavior)
- tieWordEmbeddings := false
- if tie, ok := model.Metadata.ModelConfig.Params["tie_word_embeddings"].(bool); ok {
- tieWordEmbeddings = tie
- }
- spec := convert.NewSpec(architecture, tieWordEmbeddings, baseQuant, *mixMode)
- // Create output writer
- writer, err := loader.NewWriter(outputPath)
- if err != nil {
- fmt.Printf("Error creating output: %v\n", err)
- os.Exit(1)
- }
- writer.SetModelConfig(model.Metadata.ModelConfig)
- // Copy tokenizer if present
- tokData, err := model.GetTokenizerData()
- if err == nil && len(tokData) > 0 {
- writer.AddTokenizer(tokData)
- fmt.Println("Copying embedded tokenizer...")
- }
- // Process tensors
- totalTensors := len(model.Metadata.Tensors)
- stats := make(map[quant.QuantType]int)
- skipped := 0
- fmt.Printf("\nQuantizing %d tensors...\n\n", totalTensors)
- for name, info := range model.Metadata.Tensors {
- data, err := model.GetTensorData(name)
- if err != nil {
- fmt.Printf("Error reading tensor %s: %v\n", name, err)
- continue
- }
- // Determine if quantizable
- nDims := len(info.Shape)
- isQuantizable := nDims >= 2 && info.DType == loader.F32
-
- // Check divisibility by 256
- if isQuantizable && info.Shape[len(info.Shape)-1]%256 != 0 {
- isQuantizable = false
- }
- var outData []byte
- var outDType loader.DType
-
- if isQuantizable {
- // Resolve quant type (with mix mode if enabled)
- tensorQuant := baseQuant
- tensorQuant = spec.ResolveQuant(name, baseQuant)
-
- // Handle F32 (keep as-is)
- if tensorQuant == quant.TypeF32 || tensorQuant == quant.TypeF16 {
- outData = data
- outDType = tensorQuant.ToDType()
- stats[tensorQuant]++
- fmt.Printf(" %s: %v [%s] (preserved)\n", name, info.Shape, tensorQuant)
- } else {
- // Convert bytes to float32
- floats := bytesToFloat32(data)
-
- start := time.Now()
-
- switch tensorQuant {
- case quant.TypeQ8K:
- outData = quant.QuantizeQ8K(floats)
- case quant.TypeQ5K:
- outData = quant.QuantizeQ5K(floats)
- case quant.TypeQ6K:
- outData = quant.QuantizeQ6K(floats)
- case quant.TypeQ4K:
- outData = quant.QuantizeQ4K(floats)
- case quant.TypeQ3K:
- outData = quant.QuantizeQ3K(floats)
- case quant.TypeQ2K:
- outData = quant.QuantizeQ2K(floats)
- default:
- outData = quant.QuantizeQ4K(floats)
- tensorQuant = quant.TypeQ4K
- }
-
- elapsed := time.Since(start)
- outDType = tensorQuant.ToDType()
- stats[tensorQuant]++
-
- ratio := float64(len(data)) / float64(len(outData))
-
- // Show mix info if different from base
- mixInfo := ""
- if *mixMode && tensorQuant != baseQuant {
- mixInfo = fmt.Sprintf(" (mix: %s→%s)", baseQuant, tensorQuant)
- }
-
- fmt.Printf(" %s: %v → %s (%.2fx, %v)%s\n",
- name, info.Shape, tensorQuant, ratio, elapsed, mixInfo)
- }
- } else {
- // Keep as-is
- outData = data
- outDType = info.DType
- skipped++
- }
- // Convert shape to uint64
- shape := make([]uint64, len(info.Shape))
- for i, s := range info.Shape {
- shape[i] = s
- }
- if err := writer.AddTensor(name, outDType, shape, outData); err != nil {
- fmt.Printf("Error writing tensor %s: %v\n", name, err)
- }
- }
- if err := writer.Close(); err != nil {
- fmt.Printf("Error closing output: %v\n", err)
- os.Exit(1)
- }
- // Get file sizes
- inStat, _ := os.Stat(inputPath)
- outStat, _ := os.Stat(outputPath)
- fmt.Printf("\n✓ Done!\n")
- fmt.Printf(" Quantization breakdown:\n")
- for qt, count := range stats {
- fmt.Printf(" %s: %d tensors\n", qt, count)
- }
- fmt.Printf(" Skipped: %d tensors\n", skipped)
- fmt.Printf(" Input size: %.2f MB\n", float64(inStat.Size())/(1024*1024))
- fmt.Printf(" Output size: %.2f MB\n", float64(outStat.Size())/(1024*1024))
- fmt.Printf(" Compression: %.2fx\n", float64(inStat.Size())/float64(outStat.Size()))
- fmt.Printf(" Total time: %v\n", time.Since(startTime))
- }
// bytesToFloat32 decodes data as a sequence of little-endian IEEE-754
// 32-bit floats. Trailing bytes that do not make up a full 4-byte word are
// ignored. The returned slice has len(data)/4 elements and is never nil.
func bytesToFloat32(data []byte) []float32 {
	floats := make([]float32, len(data)/4)
	for idx := range floats {
		// Uint32 reads only the first four bytes of the slice it is given.
		word := binary.LittleEndian.Uint32(data[idx*4:])
		floats[idx] = math.Float32frombits(word)
	}
	return floats
}
|