// Package kimi_linear provides configuration parsing for the Kimi Linear
// model: the MoE routing setup, the MLA attention dimensions, and the
// hybrid KDA/full-attention layer layout.
package kimi_linear

import (
	"fmt"

	"makarna/pkg/model"
)
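
// LinearAttnConfig describes the hybrid attention layout: head count, head
// dimension, and short-convolution kernel size for the KDA (linear
// attention) path, plus which layer indices run KDA and which run full
// attention.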
type LinearAttnConfig struct {
	NumHeads        int
	HeadDim         int
	ShortConvKernel int
	KDALayers       []int
	FullAttnLayers  []int
}
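
// MLAConfig holds the dimensions for the full-attention layers, which use
// MLA (multi-head latent attention): a low-rank latent KV (KVLoraRank) and
// query/key head dims split into RoPE and non-RoPE ("nope") parts.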
type MLAConfig struct {
	NumHeads      int
	KVLoraRank    int
	QKNopeHeadDim int
	QKRopeHeadDim int
	VHeadDim      int
}
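
// MoEConfig holds the mixture-of-experts routing parameters: expert and
// shared-expert counts, top-k selection (optionally grouped into expert
// groups), router activation and scaling, how many leading layers stay
// dense (FirstKDenseReplace), and how often MoE layers recur (LayerFreq).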
type MoEConfig struct {
	NumExperts           int
	TopK                 int
	IntermediateSize     int
	NumSharedExperts     int
	FirstKDenseReplace   int
	LayerFreq            int
	Renormalize          bool
	RouterActivationFunc string
	RoutedScalingFactor  float32
	UseGroupedTopK       bool
	NumExpertGroup       int
	TopKGroup            int
}
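
// parseMoEConfig reads the MoE hyperparameters from cfg.Params. Numeric
// params are asserted as float64 (the type JSON numbers decode to in an
// `any`) and narrowed to int/float32. Absent or mistyped keys leave the
// zero value, which the defaults at the end backfill for routing fields.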
func parseMoEConfig(cfg *model.Config) (MoEConfig, error) {
	var out MoEConfig
	if cfg == nil || cfg.Params == nil {
		return out, fmt.Errorf("kimi_linear: missing config params")
	}
	if v, ok := cfg.Params["num_experts"].(float64); ok {
		out.NumExperts = int(v)
	}
	if v, ok := cfg.Params["num_experts_per_token"].(float64); ok {
		out.TopK = int(v)
	}
	if v, ok := cfg.Params["moe_intermediate_size"].(float64); ok {
		out.IntermediateSize = int(v)
	}
	if v, ok := cfg.Params["num_shared_experts"].(float64); ok {
		out.NumSharedExperts = int(v)
	}
	if v, ok := cfg.Params["first_k_dense_replace"].(float64); ok {
		out.FirstKDenseReplace = int(v)
	}
	if v, ok := cfg.Params["moe_layer_freq"].(float64); ok {
		out.LayerFreq = int(v)
	}
	if v, ok := cfg.Params["moe_renormalize"].(bool); ok {
		out.Renormalize = v
	}
	if v, ok := cfg.Params["moe_router_activation_func"].(string); ok {
		out.RouterActivationFunc = v
	}
	if v, ok := cfg.Params["routed_scaling_factor"].(float64); ok {
		out.RoutedScalingFactor = float32(v)
	}
	if v, ok := cfg.Params["use_grouped_topk"].(bool); ok {
		out.UseGroupedTopK = v
	}
	if v, ok := cfg.Params["num_expert_group"].(float64); ok {
		out.NumExpertGroup = int(v)
	}
	if v, ok := cfg.Params["topk_group"].(float64); ok {
		out.TopKGroup = int(v)
	}
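
	// Backfill non-zero defaults for routing fields the config may omit.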
	if out.LayerFreq == 0 {
		out.LayerFreq = 1
	}
	if out.TopK == 0 {
		out.TopK = 1
	}
	if out.RoutedScalingFactor == 0 {
		out.RoutedScalingFactor = 1
	}
	if out.NumExpertGroup == 0 {
		out.NumExpertGroup = 1
	}
	if out.TopKGroup == 0 {
		out.TopKGroup = 1
	}
	return out, nil
}
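
// parseMLAConfig reads the MLA head dimensions from cfg.Params. Absent
// keys leave the corresponding field at zero.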
func parseMLAConfig(cfg *model.Config) (MLAConfig, error) {
	var out MLAConfig
	if cfg == nil || cfg.Params == nil {
		return out, fmt.Errorf("kimi_linear: missing config params")
	}
	if v, ok := cfg.Params["num_attention_heads"].(float64); ok {
		out.NumHeads = int(v)
	}
	if v, ok := cfg.Params["kv_lora_rank"].(float64); ok {
		out.KVLoraRank = int(v)
	}
	if v, ok := cfg.Params["qk_nope_head_dim"].(float64); ok {
		out.QKNopeHeadDim = int(v)
	}
	if v, ok := cfg.Params["qk_rope_head_dim"].(float64); ok {
		out.QKRopeHeadDim = int(v)
	}
	if v, ok := cfg.Params["v_head_dim"].(float64); ok {
		out.VHeadDim = int(v)
	}
	return out, nil
}
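
// parseLinearAttnConfig reads the nested "linear_attn_config" object from
// cfg.Params. An absent or nil block is not an error: the zero-valued
// config is returned. Any type other than a decoded JSON object is
// rejected.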
func parseLinearAttnConfig(cfg *model.Config) (LinearAttnConfig, error) {
	var out LinearAttnConfig
	if cfg == nil || cfg.Params == nil {
		return out, fmt.Errorf("kimi_linear: missing config params")
	}
	lacRaw, ok := cfg.Params["linear_attn_config"]
	if !ok || lacRaw == nil {
		return out, nil
	}
	lac, ok := lacRaw.(map[string]any)
	if !ok {
		return out, fmt.Errorf("kimi_linear: linear_attn_config has unexpected type %T", lacRaw)
	}
	if v, ok := lac["num_heads"].(float64); ok {
		out.NumHeads = int(v)
	}
	if v, ok := lac["head_dim"].(float64); ok {
		out.HeadDim = int(v)
	}
	if v, ok := lac["short_conv_kernel_size"].(float64); ok {
		out.ShortConvKernel = int(v)
	}
	out.KDALayers = parseIntList(lac["kda_layers"])
	out.FullAttnLayers = parseIntList(lac["full_attn_layers"])
	return out, nil
}
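
// parseIntList coerces a decoded parameter value into []int. It handles
// both []any (the typical JSON-decoded shape) and an already-typed
// []float64, silently skips elements of other types, and returns nil for
// anything that is not a list.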
func parseIntList(v any) []int {
	arr, ok := v.([]any)
	if !ok {
		if f, ok := v.([]float64); ok {
			out := make([]int, len(f))
			for i := range f {
				out[i] = int(f[i])
			}
			return out
		}
		return nil
	}
	out := make([]int, 0, len(arr))
	for _, it := range arr {
		switch x := it.(type) {
		case float64:
			out = append(out, int(x))
		case int:
			out = append(out, x)
		}
	}
	return out
}