sampler.go 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. // Package sample provides token sampling strategies for text generation
  2. package sample
  3. import (
  4. "container/heap"
  5. "math"
  6. "math/rand"
  7. "sort"
  8. )
  // Config bundles the knobs that control how a Sampler picks the next token.
  //
  // Each field has a neutral value that switches its stage off (see the
  // per-field comments). The zero value is not fully neutral — TopP == 0, for
  // example, truncates to the single most likely candidate — so start from
  // DefaultConfig when in doubt.
  type Config struct {
  	// Temperature divides logits before softmax; exactly 0 selects greedy
  	// decoding.
  	Temperature float32
  	// TopK keeps only the K highest-scoring tokens; values <= 0 disable it.
  	TopK int
  	// TopP is the nucleus threshold on cumulative probability; values >= 1.0
  	// disable it.
  	TopP float32
  	// MinP drops tokens whose probability is below MinP times the highest
  	// probability; values <= 0 disable it.
  	MinP float32
  	// RepetitionPenalty shrinks the logits of recently seen tokens (positive
  	// logits are divided by it, non-positive ones multiplied); 1.0 disables it.
  	RepetitionPenalty float32
  	// Seed initialises the random source; any negative value requests a
  	// randomly chosen seed.
  	Seed int64
  }
  18. // DefaultConfig returns sensible defaults
  19. func DefaultConfig() Config {
  20. return Config{
  21. Temperature: 0.7,
  22. TopK: 40,
  23. TopP: 0.9,
  24. MinP: 0.0,
  25. RepetitionPenalty: 1.1,
  26. Seed: -1,
  27. }
  28. }
  29. // Sampler handles token sampling with various strategies
  30. type Sampler struct {
  31. config Config
  32. rng *rand.Rand
  33. }
  34. // New creates a sampler with given config
  35. func New(cfg Config) *Sampler {
  36. var rng *rand.Rand
  37. if cfg.Seed >= 0 {
  38. rng = rand.New(rand.NewSource(cfg.Seed))
  39. } else {
  40. rng = rand.New(rand.NewSource(rand.Int63()))
  41. }
  42. return &Sampler{
  43. config: cfg,
  44. rng: rng,
  45. }
  46. }
  47. // Config returns a copy of the sampler configuration.
  48. func (s *Sampler) Config() Config {
  49. if s == nil {
  50. return Config{}
  51. }
  52. return s.config
  53. }
  // candidate pairs a token id with a score. The same field is overwritten in
  // place as sampling proceeds: (penalised) logit, temperature-scaled logit,
  // then probability.
  type candidate struct {
  	id    int     // token id
  	score float32 // logit or probability depending on the stage
  }
  59. type candidateHeap []candidate
  60. func (h candidateHeap) Len() int { return len(h) }
  61. func (h candidateHeap) Less(i, j int) bool {
  62. // min-heap by score
  63. return h[i].score < h[j].score
  64. }
  65. func (h candidateHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
  66. func (h *candidateHeap) Push(x interface{}) {
  67. *h = append(*h, x.(candidate))
  68. }
  69. func (h *candidateHeap) Pop() interface{} {
  70. old := *h
  71. n := len(old)
  72. x := old[n-1]
  73. *h = old[:n-1]
  74. return x
  75. }
  76. // Sample selects a token from logits using configured strategies
  77. // recentTokens is used for repetition penalty
  78. func (s *Sampler) Sample(logits []float32, recentTokens []int) int {
  79. if len(logits) == 0 {
  80. return 0
  81. }
  82. var seen map[int]struct{}
  83. if s.config.RepetitionPenalty != 1.0 && len(recentTokens) > 0 {
  84. seen = make(map[int]struct{}, len(recentTokens))
  85. for _, t := range recentTokens {
  86. seen[t] = struct{}{}
  87. }
  88. }
  89. // Temperature = 0 means greedy
  90. if s.config.Temperature == 0 {
  91. maxIdx := 0
  92. maxScore := logits[0]
  93. if seen != nil {
  94. if _, ok := seen[0]; ok {
  95. if maxScore > 0 {
  96. maxScore /= s.config.RepetitionPenalty
  97. } else {
  98. maxScore *= s.config.RepetitionPenalty
  99. }
  100. }
  101. }
  102. for i := 1; i < len(logits); i++ {
  103. score := logits[i]
  104. if seen != nil {
  105. if _, ok := seen[i]; ok {
  106. if score > 0 {
  107. score /= s.config.RepetitionPenalty
  108. } else {
  109. score *= s.config.RepetitionPenalty
  110. }
  111. }
  112. }
  113. if score > maxScore {
  114. maxScore = score
  115. maxIdx = i
  116. }
  117. }
  118. return maxIdx
  119. }
  120. // Fast path: if TopK is enabled, avoid allocating/sorting full vocab.
  121. if s.config.TopK > 0 && s.config.TopK < len(logits) {
  122. candidates := topKHeap(logits, s.config.TopK, seen, s.config.RepetitionPenalty)
  123. // Sort descending
  124. sort.Slice(candidates, func(i, j int) bool {
  125. return candidates[i].score > candidates[j].score
  126. })
  127. // Temperature scaling
  128. applyTemperature(candidates, s.config.Temperature)
  129. // Softmax
  130. applySoftmax(candidates)
  131. // Top-P (nucleus) filtering
  132. candidates = topP(candidates, s.config.TopP)
  133. // Min-P filtering
  134. candidates = minP(candidates, s.config.MinP)
  135. // Random sampling from remaining candidates
  136. return sampleFromProbs(candidates, s.rng)
  137. }
  138. // Fallback: full-vocab path (slower, but preserves behavior for TopK<=0)
  139. candidates := make([]candidate, len(logits))
  140. for i, l := range logits {
  141. candidates[i] = candidate{id: i, score: l}
  142. }
  143. if seen != nil {
  144. applyRepetitionPenalty(candidates, recentTokens, s.config.RepetitionPenalty)
  145. }
  146. // Top-K filtering and sorting
  147. candidates = topK(candidates, s.config.TopK)
  148. // Temperature scaling
  149. applyTemperature(candidates, s.config.Temperature)
  150. // Softmax
  151. applySoftmax(candidates)
  152. // Top-P (nucleus) filtering
  153. candidates = topP(candidates, s.config.TopP)
  154. // Min-P filtering
  155. candidates = minP(candidates, s.config.MinP)
  156. // Random sampling from remaining candidates
  157. return sampleFromProbs(candidates, s.rng)
  158. }
  159. func topKHeap(logits []float32, k int, seen map[int]struct{}, penalty float32) []candidate {
  160. if k <= 0 {
  161. k = len(logits)
  162. }
  163. if k > len(logits) {
  164. k = len(logits)
  165. }
  166. h := make(candidateHeap, 0, k)
  167. for i, l := range logits {
  168. score := l
  169. if seen != nil {
  170. if _, ok := seen[i]; ok {
  171. if score > 0 {
  172. score /= penalty
  173. } else {
  174. score *= penalty
  175. }
  176. }
  177. }
  178. c := candidate{id: i, score: score}
  179. if len(h) < k {
  180. heap.Push(&h, c)
  181. continue
  182. }
  183. if k > 0 && c.score > h[0].score {
  184. h[0] = c
  185. heap.Fix(&h, 0)
  186. }
  187. }
  188. // Return as a normal slice (unsorted)
  189. out := make([]candidate, len(h))
  190. copy(out, h)
  191. return out
  192. }
  193. // SampleFromTopK samples from an already top-k filtered candidate list.
  194. // The provided logits must already include repetition penalty if enabled.
  195. func (s *Sampler) SampleFromTopK(ids []int32, logits []float32) int {
  196. if len(ids) == 0 || len(logits) == 0 {
  197. return 0
  198. }
  199. if len(ids) != len(logits) {
  200. // best-effort: clamp to shorter
  201. n := len(ids)
  202. if len(logits) < n {
  203. n = len(logits)
  204. }
  205. ids = ids[:n]
  206. logits = logits[:n]
  207. }
  208. // Greedy path
  209. if s.config.Temperature == 0 {
  210. maxIdx := 0
  211. maxScore := logits[0]
  212. for i := 1; i < len(logits); i++ {
  213. if logits[i] > maxScore {
  214. maxScore = logits[i]
  215. maxIdx = i
  216. }
  217. }
  218. return int(ids[maxIdx])
  219. }
  220. candidates := make([]candidate, len(ids))
  221. for i := range ids {
  222. candidates[i] = candidate{id: int(ids[i]), score: logits[i]}
  223. }
  224. // Sort descending
  225. sort.Slice(candidates, func(i, j int) bool {
  226. return candidates[i].score > candidates[j].score
  227. })
  228. applyTemperature(candidates, s.config.Temperature)
  229. applySoftmax(candidates)
  230. candidates = topP(candidates, s.config.TopP)
  231. candidates = minP(candidates, s.config.MinP)
  232. return sampleFromProbs(candidates, s.rng)
  233. }
  // SampleGreedy returns the index of the largest value in logits, preferring
  // the lowest index on ties. An empty or nil slice yields 0, the same fallback
  // Sampler.Sample uses, rather than panicking on logits[0].
  func SampleGreedy(logits []float32) int {
  	best := 0
  	// Starting at 1 and comparing against logits[best] means an empty slice
  	// never indexes anything.
  	for i := 1; i < len(logits); i++ {
  		if logits[i] > logits[best] {
  			best = i
  		}
  	}
  	return best
  }
  246. // greedy returns id of highest scoring candidate
  247. func greedy(candidates []candidate) int {
  248. maxIdx := 0
  249. maxScore := candidates[0].score
  250. for i := 1; i < len(candidates); i++ {
  251. if candidates[i].score > maxScore {
  252. maxScore = candidates[i].score
  253. maxIdx = i
  254. }
  255. }
  256. return candidates[maxIdx].id
  257. }
  258. // topK keeps only the k highest scoring candidates (sorted descending)
  259. func topK(candidates []candidate, k int) []candidate {
  260. if k <= 0 || k >= len(candidates) {
  261. // Sort all descending
  262. sort.Slice(candidates, func(i, j int) bool {
  263. return candidates[i].score > candidates[j].score
  264. })
  265. return candidates
  266. }
  267. // Sort all descending and take top k
  268. sort.Slice(candidates, func(i, j int) bool {
  269. return candidates[i].score > candidates[j].score
  270. })
  271. return candidates[:k]
  272. }
  273. // applyTemperature scales logits by temperature
  274. func applyTemperature(candidates []candidate, temp float32) {
  275. if temp < 1e-7 {
  276. temp = 1e-7 // Avoid division by zero
  277. }
  278. for i := range candidates {
  279. candidates[i].score /= temp
  280. }
  281. }
  282. // applySoftmax converts logits to probabilities
  283. func applySoftmax(candidates []candidate) {
  284. // Find max for numerical stability
  285. maxScore := candidates[0].score
  286. for _, c := range candidates[1:] {
  287. if c.score > maxScore {
  288. maxScore = c.score
  289. }
  290. }
  291. // Compute exp(x - max)
  292. var sum float32
  293. for i := range candidates {
  294. candidates[i].score = float32(math.Exp(float64(candidates[i].score - maxScore)))
  295. sum += candidates[i].score
  296. }
  297. // Normalize
  298. for i := range candidates {
  299. candidates[i].score /= sum
  300. }
  301. }
  302. // topP keeps tokens until cumulative probability exceeds p (nucleus sampling)
  303. func topP(candidates []candidate, p float32) []candidate {
  304. if p >= 1.0 {
  305. return candidates
  306. }
  307. var cumSum float32
  308. for i, c := range candidates {
  309. cumSum += c.score
  310. if cumSum > p {
  311. if i == 0 {
  312. return candidates[:1]
  313. }
  314. return candidates[:i+1]
  315. }
  316. }
  317. return candidates
  318. }
  319. // minP keeps tokens with probability >= p * max_probability
  320. func minP(candidates []candidate, p float32) []candidate {
  321. if p <= 0 || len(candidates) == 0 {
  322. return candidates
  323. }
  324. maxProb := candidates[0].score // Assumes sorted descending
  325. threshold := maxProb * p
  326. for i, c := range candidates {
  327. if c.score < threshold {
  328. if i == 0 {
  329. return candidates[:1]
  330. }
  331. return candidates[:i]
  332. }
  333. }
  334. return candidates
  335. }
  336. // applyRepetitionPenalty penalizes recently used tokens
  337. func applyRepetitionPenalty(candidates []candidate, recentTokens []int, penalty float32) {
  338. seen := make(map[int]bool)
  339. for _, t := range recentTokens {
  340. seen[t] = true
  341. }
  342. for i := range candidates {
  343. if seen[candidates[i].id] {
  344. if candidates[i].score > 0 {
  345. candidates[i].score /= penalty
  346. } else {
  347. candidates[i].score *= penalty
  348. }
  349. }
  350. }
  351. }
  352. // sampleFromProbs randomly selects a token based on probability distribution
  353. func sampleFromProbs(candidates []candidate, rng *rand.Rand) int {
  354. if len(candidates) == 0 {
  355. return 0
  356. }
  357. if len(candidates) == 1 {
  358. return candidates[0].id
  359. }
  360. // Compute cumulative sum
  361. cumSum := make([]float32, len(candidates))
  362. cumSum[0] = candidates[0].score
  363. for i := 1; i < len(candidates); i++ {
  364. cumSum[i] = cumSum[i-1] + candidates[i].score
  365. }
  366. // Random sample
  367. r := rng.Float32() * cumSum[len(cumSum)-1]
  368. // Binary search
  369. for i, cs := range cumSum {
  370. if r <= cs {
  371. return candidates[i].id
  372. }
  373. }
  374. return candidates[len(candidates)-1].id
  375. }