config.go

package kimi_linear

import (
    "fmt"

    "makarna/pkg/model"
)
// LinearAttnConfig describes the linear-attention (KDA) layers: head
// geometry, the short-convolution kernel size, and which layer indices
// use KDA versus full attention.
type LinearAttnConfig struct {
    NumHeads        int
    HeadDim         int
    ShortConvKernel int
    KDALayers       []int
    FullAttnLayers  []int
}
// MLAConfig holds the multi-head latent attention dimensions: the
// low-rank KV projection rank and the per-head widths of the NoPE and
// RoPE query-key parts and of the value.
type MLAConfig struct {
    NumHeads      int
    KVLoraRank    int
    QKNopeHeadDim int
    QKRopeHeadDim int
    VHeadDim      int
}
// MoEConfig holds the mixture-of-experts sizing and routing parameters.
type MoEConfig struct {
    NumExperts           int
    TopK                 int
    IntermediateSize     int
    NumSharedExperts     int
    FirstKDenseReplace   int
    LayerFreq            int
    Renormalize          bool
    RouterActivationFunc string
    RoutedScalingFactor  float32
    UseGroupedTopK       bool
    NumExpertGroup       int
    TopKGroup            int
}
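
// The parsers below read cfg.Params, a generic map[string]any that is
// typically produced by json.Unmarshal, where every JSON number arrives
// as float64; that is why numeric fields are asserted as float64 and
// then converted to int. A minimal caller sketch (the JSON literal and
// variable names are illustrative, not from this package):
//
//	var params map[string]any
//	_ = json.Unmarshal([]byte(`{"num_experts": 64, "num_experts_per_token": 6}`), &params)
//	moe, err := parseMoEConfig(&model.Config{Params: params})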
// parseMoEConfig extracts the mixture-of-experts parameters from the
// generic config map. Missing keys simply leave the zero value in
// place; a few fields are then defaulted to 1.
func parseMoEConfig(cfg *model.Config) (MoEConfig, error) {
    var out MoEConfig
    if cfg == nil || cfg.Params == nil {
        return out, fmt.Errorf("kimi_linear: missing config params")
    }
    if v, ok := cfg.Params["num_experts"].(float64); ok {
        out.NumExperts = int(v)
    }
    if v, ok := cfg.Params["num_experts_per_token"].(float64); ok {
        out.TopK = int(v)
    }
    if v, ok := cfg.Params["moe_intermediate_size"].(float64); ok {
        out.IntermediateSize = int(v)
    }
    if v, ok := cfg.Params["num_shared_experts"].(float64); ok {
        out.NumSharedExperts = int(v)
    }
    if v, ok := cfg.Params["first_k_dense_replace"].(float64); ok {
        out.FirstKDenseReplace = int(v)
    }
    if v, ok := cfg.Params["moe_layer_freq"].(float64); ok {
        out.LayerFreq = int(v)
    }
    if v, ok := cfg.Params["moe_renormalize"].(bool); ok {
        out.Renormalize = v
    }
    if v, ok := cfg.Params["moe_router_activation_func"].(string); ok {
        out.RouterActivationFunc = v
    }
    if v, ok := cfg.Params["routed_scaling_factor"].(float64); ok {
        out.RoutedScalingFactor = float32(v)
    }
    if v, ok := cfg.Params["use_grouped_topk"].(bool); ok {
        out.UseGroupedTopK = v
    }
    if v, ok := cfg.Params["num_expert_group"].(float64); ok {
        out.NumExpertGroup = int(v)
    }
    if v, ok := cfg.Params["topk_group"].(float64); ok {
        out.TopKGroup = int(v)
    }
    // Fields used as divisors or counts default to 1 when unset, so
    // downstream code never divides or loops by zero.
    if out.LayerFreq == 0 {
        out.LayerFreq = 1
    }
    if out.TopK == 0 {
        out.TopK = 1
    }
    if out.RoutedScalingFactor == 0 {
        out.RoutedScalingFactor = 1
    }
    if out.NumExpertGroup == 0 {
        out.NumExpertGroup = 1
    }
    if out.TopKGroup == 0 {
        out.TopKGroup = 1
    }
    return out, nil
}
// parseMLAConfig extracts the multi-head latent attention dimensions
// from the generic config map; missing keys leave zero values.
func parseMLAConfig(cfg *model.Config) (MLAConfig, error) {
    var out MLAConfig
    if cfg == nil || cfg.Params == nil {
        return out, fmt.Errorf("kimi_linear: missing config params")
    }
    if v, ok := cfg.Params["num_attention_heads"].(float64); ok {
        out.NumHeads = int(v)
    }
    if v, ok := cfg.Params["kv_lora_rank"].(float64); ok {
        out.KVLoraRank = int(v)
    }
    if v, ok := cfg.Params["qk_nope_head_dim"].(float64); ok {
        out.QKNopeHeadDim = int(v)
    }
    if v, ok := cfg.Params["qk_rope_head_dim"].(float64); ok {
        out.QKRopeHeadDim = int(v)
    }
    if v, ok := cfg.Params["v_head_dim"].(float64); ok {
        out.VHeadDim = int(v)
    }
    return out, nil
}
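
// Sketch of how these dimensions typically relate (an assumption based
// on common MLA layouts, not stated in this file): the full query/key
// head width is the NoPE and RoPE parts concatenated, e.g.
//
//	qkHeadDim := mla.QKNopeHeadDim + mla.QKRopeHeadDim // hypothetical derived value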
// parseLinearAttnConfig extracts the nested linear_attn_config block.
// The block is optional: if the key is absent, the zero-value config is
// returned without error; a present value of the wrong type is an error.
func parseLinearAttnConfig(cfg *model.Config) (LinearAttnConfig, error) {
    var out LinearAttnConfig
    if cfg == nil || cfg.Params == nil {
        return out, fmt.Errorf("kimi_linear: missing config params")
    }
    lacRaw, ok := cfg.Params["linear_attn_config"]
    if !ok || lacRaw == nil {
        return out, nil
    }
    lac, ok := lacRaw.(map[string]any)
    if !ok {
        return out, fmt.Errorf("kimi_linear: linear_attn_config has unexpected type %T", lacRaw)
    }
    if v, ok := lac["num_heads"].(float64); ok {
        out.NumHeads = int(v)
    }
    if v, ok := lac["head_dim"].(float64); ok {
        out.HeadDim = int(v)
    }
    if v, ok := lac["short_conv_kernel_size"].(float64); ok {
        out.ShortConvKernel = int(v)
    }
    out.KDALayers = parseIntList(lac["kda_layers"])
    out.FullAttnLayers = parseIntList(lac["full_attn_layers"])
    return out, nil
}
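
// For reference, a linear_attn_config block as this parser expects it
// to decode (the concrete values are illustrative only):
//
//	{
//	    "num_heads": 32,
//	    "head_dim": 128,
//	    "short_conv_kernel_size": 4,
//	    "kda_layers": [0, 1, 2],
//	    "full_attn_layers": [3]
//	}
//
// After json.Unmarshal the lists arrive as []any of float64, which
// parseIntList below normalizes to []int.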
// parseIntList converts a JSON-decoded list into []int. It accepts
// either []any (the form encoding/json produces) with float64 or int
// elements, or an already-typed []float64; anything else yields nil.
func parseIntList(v any) []int {
    arr, ok := v.([]any)
    if !ok {
        if f, ok := v.([]float64); ok {
            out := make([]int, len(f))
            for i := range f {
                out[i] = int(f[i])
            }
            return out
        }
        return nil
    }
    out := make([]int, 0, len(arr))
    for _, it := range arr {
        switch x := it.(type) {
        case float64:
            out = append(out, int(x))
        case int:
            out = append(out, x)
        }
    }
    return out
}
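
// Illustrative calls (not part of the original file) covering the three
// input shapes parseIntList handles:
//
//	parseIntList([]any{float64(1), float64(3)}) // []int{1, 3}
//	parseIntList([]float64{2, 4})               // []int{2, 4}
//	parseIntList("not a list")                  // nil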