| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- package kimi_linear
- import (
- "strings"
- "makarna/pkg/convert"
- "makarna/pkg/quant"
- )
type (
	// convertPlugin carries no state. A zero value is registered under
	// "kimilinear" in init, and its Apply method customizes how tensors are
	// quantized during conversion.
	convertPlugin struct{}
)
- func (convertPlugin) Apply(spec *convert.Spec) {
- prev := spec.ResolveQuant
- mixRules := map[quant.QuantType][]quant.Rule{
- quant.TypeQ4K: {
- {Pattern: "*embed_tokens*", QuantType: quant.TypeQ6K},
- {Pattern: "*lm_head*", QuantType: quant.TypeQ6K},
- {Pattern: "*norm*", QuantType: quant.TypeF32},
- {Pattern: "*block_sparse_moe.gate*", QuantType: quant.TypeQ6K},
- },
- quant.TypeQ3K: {
- {Pattern: "*embed_tokens*", QuantType: quant.TypeQ6K},
- {Pattern: "*lm_head*", QuantType: quant.TypeQ6K},
- {Pattern: "*norm*", QuantType: quant.TypeF32},
- {Pattern: "*block_sparse_moe.gate*", QuantType: quant.TypeQ6K},
- },
- quant.TypeQ2K: {
- {Pattern: "*embed_tokens*", QuantType: quant.TypeQ6K},
- {Pattern: "*lm_head*", QuantType: quant.TypeQ6K},
- {Pattern: "*norm*", QuantType: quant.TypeF32},
- {Pattern: "*block_sparse_moe.gate*", QuantType: quant.TypeQ6K},
- },
- }
- spec.ResolveQuant = func(name string, baseQuant quant.QuantType) quant.QuantType {
- qt := baseQuant
- if prev != nil {
- qt = prev(name, baseQuant)
- }
- if spec.MixMode {
- if rules, ok := mixRules[baseQuant]; ok {
- qt = quant.ApplyRules(name, baseQuant, rules)
- }
- }
- lname := strings.ToLower(name)
- if strings.Contains(lname, "norm") {
- return quant.TypeF32
- }
- if strings.Contains(lname, "embed_tokens") {
- switch baseQuant {
- case quant.TypeQ6K:
- return quant.TypeQ8K
- case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K:
- return quant.TypeQ6K
- }
- }
- if strings.Contains(lname, "lm_head") {
- switch baseQuant {
- case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K:
- return quant.TypeQ6K
- }
- }
- if strings.Contains(lname, "block_sparse_moe.gate") {
- switch baseQuant {
- case quant.TypeQ2K, quant.TypeQ3K, quant.TypeQ4K:
- return quant.TypeQ6K
- }
- }
- return qt
- }
- }
- func init() {
- convert.Register("kimilinear", convertPlugin{})
- }
|