convert_plugin.go 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. package kimi_linear
  2. import (
  3. "strings"
  4. "makarna/pkg/convert"
  5. "makarna/pkg/quant"
  6. )
  7. type convertPlugin struct{}
  8. func (convertPlugin) Apply(spec *convert.Spec) {
  9. prev := spec.ResolveQuant
  10. mixRules := map[quant.QuantType][]quant.Rule{
  11. quant.TypeQ4K: {
  12. {Pattern: "*embed_tokens*", QuantType: quant.TypeQ6K},
  13. {Pattern: "*lm_head*", QuantType: quant.TypeQ6K},
  14. {Pattern: "*norm*", QuantType: quant.TypeF32},
  15. {Pattern: "*block_sparse_moe.gate*", QuantType: quant.TypeQ6K},
  16. },
  17. quant.TypeQ3K: {
  18. {Pattern: "*embed_tokens*", QuantType: quant.TypeQ6K},
  19. {Pattern: "*lm_head*", QuantType: quant.TypeQ6K},
  20. {Pattern: "*norm*", QuantType: quant.TypeF32},
  21. {Pattern: "*block_sparse_moe.gate*", QuantType: quant.TypeQ6K},
  22. },
  23. quant.TypeQ2K: {
  24. {Pattern: "*embed_tokens*", QuantType: quant.TypeQ6K},
  25. {Pattern: "*lm_head*", QuantType: quant.TypeQ6K},
  26. {Pattern: "*norm*", QuantType: quant.TypeF32},
  27. {Pattern: "*block_sparse_moe.gate*", QuantType: quant.TypeQ6K},
  28. },
  29. }
  30. spec.ResolveQuant = func(name string, baseQuant quant.QuantType) quant.QuantType {
  31. qt := baseQuant
  32. if prev != nil {
  33. qt = prev(name, baseQuant)
  34. }
  35. if spec.MixMode {
  36. if rules, ok := mixRules[baseQuant]; ok {
  37. qt = quant.ApplyRules(name, baseQuant, rules)
  38. }
  39. }
  40. lname := strings.ToLower(name)
  41. if strings.Contains(lname, "norm") {
  42. return quant.TypeF32
  43. }
  44. if strings.Contains(lname, "embed_tokens") {
  45. switch baseQuant {
  46. case quant.TypeQ6K:
  47. return quant.TypeQ8K
  48. case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K:
  49. return quant.TypeQ6K
  50. }
  51. }
  52. if strings.Contains(lname, "lm_head") {
  53. switch baseQuant {
  54. case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K:
  55. return quant.TypeQ6K
  56. }
  57. }
  58. if strings.Contains(lname, "block_sparse_moe.gate") {
  59. switch baseQuant {
  60. case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K:
  61. return quant.TypeQ6K
  62. }
  63. }
  64. return qt
  65. }
  66. }
  67. func init() {
  68. convert.Register("kimilinear", convertPlugin{})
  69. }