convert_plugin.go (1.9 KB)
  1. package qwen3
  2. import (
  3. "strings"
  4. "makarna/pkg/convert"
  5. "makarna/pkg/quant"
  6. )
  7. type convertPlugin struct{}
  8. func (convertPlugin) Apply(spec *convert.Spec) {
  9. prev := spec.ResolveQuant
  10. mixRules := map[quant.QuantType][]quant.Rule{
  11. quant.TypeQ4K: {
  12. {Pattern: "*embed_tokens*", QuantType: quant.TypeQ6K},
  13. {Pattern: "*norm*", QuantType: quant.TypeF32},
  14. },
  15. quant.TypeQ3K: {
  16. {Pattern: "*embed_tokens*", QuantType: quant.TypeQ6K},
  17. {Pattern: "*lm_head*", QuantType: quant.TypeQ6K},
  18. {Pattern: "*norm*", QuantType: quant.TypeF32},
  19. },
  20. quant.TypeQ6K: {
  21. {Pattern: "*embed_tokens*", QuantType: quant.TypeQ8K},
  22. },
  23. quant.TypeQ2K: {
  24. {Pattern: "*embed_tokens*", QuantType: quant.TypeQ6K},
  25. {Pattern: "*lm_head*", QuantType: quant.TypeQ6K},
  26. {Pattern: "*v_proj*", QuantType: quant.TypeQ3K},
  27. {Pattern: "*o_proj*", QuantType: quant.TypeQ3K},
  28. {Pattern: "*down_proj*", QuantType: quant.TypeQ3K},
  29. },
  30. }
  31. spec.ResolveQuant = func(name string, baseQuant quant.QuantType) quant.QuantType {
  32. qt := baseQuant
  33. if prev != nil {
  34. qt = prev(name, baseQuant)
  35. }
  36. if spec.MixMode {
  37. if rules, ok := mixRules[baseQuant]; ok {
  38. qt = quant.ApplyRules(name, baseQuant, rules)
  39. }
  40. }
  41. lname := strings.ToLower(name)
  42. // Keep norms in F32 when requested (safe even if tensor is not quantizable).
  43. if strings.Contains(lname, "norm") {
  44. return quant.TypeF32
  45. }
  46. // Enforce higher quality for embeddings and head even when mix mode is off.
  47. // This is a model-specific policy.
  48. if strings.Contains(lname, "embed_tokens") {
  49. switch baseQuant {
  50. case quant.TypeQ6K:
  51. return quant.TypeQ8K
  52. case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K:
  53. return quant.TypeQ6K
  54. }
  55. }
  56. if strings.Contains(lname, "lm_head") {
  57. switch baseQuant {
  58. case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K:
  59. return quant.TypeQ6K
  60. }
  61. }
  62. return qt
  63. }
  64. }
  65. func init() {
  66. convert.Register("qwen3", convertPlugin{})
  67. }