1
0

moe.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. package nn
  2. import (
  3. "sort"
  4. )
  5. // MoEChoice represents a selected expert with its weight.
  6. type MoEChoice struct {
  7. Idx int
  8. Weight float32
  9. }
  10. // TopKIndices returns the indices of the top-k largest values in scores.
  11. func TopKIndices(scores []float32, k int) []int {
  12. if k <= 0 {
  13. k = 1
  14. }
  15. if k > len(scores) {
  16. k = len(scores)
  17. }
  18. idx := make([]int, len(scores))
  19. for i := range idx {
  20. idx[i] = i
  21. }
  22. sort.Slice(idx, func(i, j int) bool { return scores[idx[i]] > scores[idx[j]] })
  23. return idx[:k]
  24. }
  25. // SelectTopKExperts selects top-k experts from scores and returns their indices and weights.
  26. // If useOriginalWeights is provided, weights are taken from that slice instead of scores.
  27. func SelectTopKExperts(scores []float32, k int, useOriginalWeights []float32) []MoEChoice {
  28. if k <= 0 {
  29. k = 1
  30. }
  31. if k > len(scores) {
  32. k = len(scores)
  33. }
  34. type scored struct {
  35. idx int
  36. score float32
  37. }
  38. choices := make([]scored, len(scores))
  39. for i, s := range scores {
  40. choices[i] = scored{idx: i, score: s}
  41. }
  42. sort.Slice(choices, func(i, j int) bool { return choices[i].score > choices[j].score })
  43. result := make([]MoEChoice, k)
  44. for i := 0; i < k; i++ {
  45. idx := choices[i].idx
  46. w := scores[idx]
  47. if useOriginalWeights != nil {
  48. w = useOriginalWeights[idx]
  49. }
  50. result[i] = MoEChoice{Idx: idx, Weight: w}
  51. }
  52. return result
  53. }
  54. // GroupedTopKMask applies grouped top-k selection as in DeepSeek-V3/Kimi MoE.
  55. // Returns a masked score array where only experts from selected groups are non-zero.
  56. //
  57. // Parameters:
  58. // - scores: router scores with bias already added (for selection)
  59. // - numGroups: number of expert groups
  60. // - topKGroup: how many groups to keep
  61. //
  62. // Returns masked scores suitable for final top-k selection.
  63. func GroupedTopKMask(scores []float32, numGroups, topKGroup int) []float32 {
  64. if numGroups <= 0 {
  65. numGroups = 1
  66. }
  67. numExperts := len(scores)
  68. if numExperts%numGroups != 0 {
  69. return scores // fallback: no masking
  70. }
  71. perGroup := numExperts / numGroups
  72. // Compute group scores (sum of top-2 in each group)
  73. groupScores := make([]float32, numGroups)
  74. for gi := 0; gi < numGroups; gi++ {
  75. base := gi * perGroup
  76. seg := scores[base : base+perGroup]
  77. top2Idx := TopKIndices(seg, 2)
  78. s := float32(0)
  79. for _, id := range top2Idx {
  80. s += seg[id]
  81. }
  82. groupScores[gi] = s
  83. }
  84. // Select top-k groups
  85. keepGroups := TopKIndices(groupScores, topKGroup)
  86. keep := make([]bool, numGroups)
  87. for _, gi := range keepGroups {
  88. keep[gi] = true
  89. }
  90. // Create masked output
  91. masked := make([]float32, numExperts)
  92. for gi := 0; gi < numGroups; gi++ {
  93. base := gi * perGroup
  94. if keep[gi] {
  95. copy(masked[base:base+perGroup], scores[base:base+perGroup])
  96. }
  97. // else: zeros (already initialized)
  98. }
  99. return masked
  100. }
  101. // RenormalizeMoEWeights normalizes weights to sum to 1.
  102. func RenormalizeMoEWeights(choices []MoEChoice) {
  103. if len(choices) == 0 {
  104. return
  105. }
  106. sum := float32(0)
  107. for _, c := range choices {
  108. sum += c.Weight
  109. }
  110. if sum == 0 {
  111. return
  112. }
  113. inv := 1 / sum
  114. for i := range choices {
  115. choices[i].Weight *= inv
  116. }
  117. }
  118. // ScaleMoEWeights multiplies all weights by the given factor.
  119. func ScaleMoEWeights(choices []MoEChoice, factor float32) {
  120. for i := range choices {
  121. choices[i].Weight *= factor
  122. }
  123. }
  124. // MoERouterActivation applies activation function to router logits.
  125. func MoERouterActivation(logits []float32, activationFunc string) []float32 {
  126. scores := make([]float32, len(logits))
  127. switch activationFunc {
  128. case "sigmoid":
  129. for i, v := range logits {
  130. scores[i] = Sigmoid(v)
  131. }
  132. case "softmax":
  133. copy(scores, logits)
  134. SoftmaxInplaceSimple(scores)
  135. default:
  136. // Default to sigmoid
  137. for i, v := range logits {
  138. scores[i] = Sigmoid(v)
  139. }
  140. }
  141. return scores
  142. }
  143. // SoftmaxInplaceSimple applies softmax normalization in-place (simple scalar version).
  144. func SoftmaxInplaceSimple(data []float32) {
  145. if len(data) == 0 {
  146. return
  147. }
  148. maxVal := data[0]
  149. for _, v := range data[1:] {
  150. if v > maxVal {
  151. maxVal = v
  152. }
  153. }
  154. sum := float32(0)
  155. for i := range data {
  156. data[i] = Exp(data[i] - maxVal)
  157. sum += data[i]
  158. }
  159. if sum == 0 {
  160. return
  161. }
  162. inv := 1 / sum
  163. for i := range data {
  164. data[i] *= inv
  165. }
  166. }