main.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. // Command quantize converts F32/F16 .mak files to quantized versions
  2. //
  3. // Usage:
  4. //
  5. // quantize input.mak output.mak q4_k
  6. // quantize input.mak output.mak q4_k --mix (enables smart mix quantization)
  7. package main
  8. import (
  9. "encoding/binary"
  10. "flag"
  11. "fmt"
  12. "math"
  13. "os"
  14. "time"
  15. "makarna/pkg/convert"
  16. "makarna/pkg/loader"
  17. "makarna/pkg/quant"
  18. _ "makarna/pkg/model/models"
  19. )
  20. func main() {
  21. // Parse flags
  22. mixMode := flag.Bool("mix", false, "Enable smart mix quantization (uses architecture-specific rules)")
  23. flag.Parse()
  24. args := flag.Args()
  25. if len(args) < 3 {
  26. fmt.Println("Usage: quantize <input.mak> <output.mak> <quant_type> [--mix]")
  27. fmt.Println("Quant types: q2_k, q3_k, q4_k, q5_k, q6_k, q8_k")
  28. fmt.Println("Flags:")
  29. fmt.Println(" --mix Enable smart mix quantization (uses architecture-specific rules)")
  30. os.Exit(1)
  31. }
  32. inputPath := args[0]
  33. outputPath := args[1]
  34. quantTypeStr := args[2]
  35. baseQuant := quant.QuantType(quantTypeStr)
  36. // Validate quant type
  37. switch baseQuant {
  38. case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K, quant.TypeQ5K, quant.TypeQ6K, quant.TypeQ8K:
  39. // OK
  40. default:
  41. fmt.Printf("Unknown quant type: %s\n", quantTypeStr)
  42. os.Exit(1)
  43. }
  44. fmt.Printf("Loading %s...\n", inputPath)
  45. startTime := time.Now()
  46. model, err := loader.Load(inputPath)
  47. if err != nil {
  48. fmt.Printf("Error loading model: %v\n", err)
  49. os.Exit(1)
  50. }
  51. defer model.Close()
  52. architecture := model.Metadata.ModelConfig.Architecture
  53. fmt.Printf("Loaded in %v\n", time.Since(startTime))
  54. fmt.Printf("Architecture: %s\n", architecture)
  55. // Build a policy spec (model plugin may override mix behavior)
  56. tieWordEmbeddings := false
  57. if tie, ok := model.Metadata.ModelConfig.Params["tie_word_embeddings"].(bool); ok {
  58. tieWordEmbeddings = tie
  59. }
  60. spec := convert.NewSpec(architecture, tieWordEmbeddings, baseQuant, *mixMode)
  61. // Create output writer
  62. writer, err := loader.NewWriter(outputPath)
  63. if err != nil {
  64. fmt.Printf("Error creating output: %v\n", err)
  65. os.Exit(1)
  66. }
  67. writer.SetModelConfig(model.Metadata.ModelConfig)
  68. // Copy tokenizer if present
  69. tokData, err := model.GetTokenizerData()
  70. if err == nil && len(tokData) > 0 {
  71. writer.AddTokenizer(tokData)
  72. fmt.Println("Copying embedded tokenizer...")
  73. }
  74. // Process tensors
  75. totalTensors := len(model.Metadata.Tensors)
  76. stats := make(map[quant.QuantType]int)
  77. skipped := 0
  78. fmt.Printf("\nQuantizing %d tensors...\n\n", totalTensors)
  79. for name, info := range model.Metadata.Tensors {
  80. data, err := model.GetTensorData(name)
  81. if err != nil {
  82. fmt.Printf("Error reading tensor %s: %v\n", name, err)
  83. continue
  84. }
  85. // Determine if quantizable
  86. nDims := len(info.Shape)
  87. isQuantizable := nDims >= 2 && info.DType == loader.F32
  88. // Check divisibility by 256
  89. if isQuantizable && info.Shape[len(info.Shape)-1]%256 != 0 {
  90. isQuantizable = false
  91. }
  92. var outData []byte
  93. var outDType loader.DType
  94. if isQuantizable {
  95. // Resolve quant type (with mix mode if enabled)
  96. tensorQuant := baseQuant
  97. tensorQuant = spec.ResolveQuant(name, baseQuant)
  98. // Handle F32 (keep as-is)
  99. if tensorQuant == quant.TypeF32 || tensorQuant == quant.TypeF16 {
  100. outData = data
  101. outDType = tensorQuant.ToDType()
  102. stats[tensorQuant]++
  103. fmt.Printf(" %s: %v [%s] (preserved)\n", name, info.Shape, tensorQuant)
  104. } else {
  105. // Convert bytes to float32
  106. floats := bytesToFloat32(data)
  107. start := time.Now()
  108. switch tensorQuant {
  109. case quant.TypeQ8K:
  110. outData = quant.QuantizeQ8K(floats)
  111. case quant.TypeQ5K:
  112. outData = quant.QuantizeQ5K(floats)
  113. case quant.TypeQ6K:
  114. outData = quant.QuantizeQ6K(floats)
  115. case quant.TypeQ4K:
  116. outData = quant.QuantizeQ4K(floats)
  117. case quant.TypeQ3K:
  118. outData = quant.QuantizeQ3K(floats)
  119. case quant.TypeQ2K:
  120. outData = quant.QuantizeQ2K(floats)
  121. default:
  122. outData = quant.QuantizeQ4K(floats)
  123. tensorQuant = quant.TypeQ4K
  124. }
  125. elapsed := time.Since(start)
  126. outDType = tensorQuant.ToDType()
  127. stats[tensorQuant]++
  128. ratio := float64(len(data)) / float64(len(outData))
  129. // Show mix info if different from base
  130. mixInfo := ""
  131. if *mixMode && tensorQuant != baseQuant {
  132. mixInfo = fmt.Sprintf(" (mix: %s→%s)", baseQuant, tensorQuant)
  133. }
  134. fmt.Printf(" %s: %v → %s (%.2fx, %v)%s\n",
  135. name, info.Shape, tensorQuant, ratio, elapsed, mixInfo)
  136. }
  137. } else {
  138. // Keep as-is
  139. outData = data
  140. outDType = info.DType
  141. skipped++
  142. }
  143. // Convert shape to uint64
  144. shape := make([]uint64, len(info.Shape))
  145. for i, s := range info.Shape {
  146. shape[i] = s
  147. }
  148. if err := writer.AddTensor(name, outDType, shape, outData); err != nil {
  149. fmt.Printf("Error writing tensor %s: %v\n", name, err)
  150. }
  151. }
  152. if err := writer.Close(); err != nil {
  153. fmt.Printf("Error closing output: %v\n", err)
  154. os.Exit(1)
  155. }
  156. // Get file sizes
  157. inStat, _ := os.Stat(inputPath)
  158. outStat, _ := os.Stat(outputPath)
  159. fmt.Printf("\n✓ Done!\n")
  160. fmt.Printf(" Quantization breakdown:\n")
  161. for qt, count := range stats {
  162. fmt.Printf(" %s: %d tensors\n", qt, count)
  163. }
  164. fmt.Printf(" Skipped: %d tensors\n", skipped)
  165. fmt.Printf(" Input size: %.2f MB\n", float64(inStat.Size())/(1024*1024))
  166. fmt.Printf(" Output size: %.2f MB\n", float64(outStat.Size())/(1024*1024))
  167. fmt.Printf(" Compression: %.2fx\n", float64(inStat.Size())/float64(outStat.Size()))
  168. fmt.Printf(" Total time: %v\n", time.Since(startTime))
  169. }
  170. func bytesToFloat32(data []byte) []float32 {
  171. n := len(data) / 4
  172. result := make([]float32, n)
  173. for i := 0; i < n; i++ {
  174. bits := binary.LittleEndian.Uint32(data[i*4 : i*4+4])
  175. result[i] = math.Float32frombits(bits)
  176. }
  177. return result
  178. }