| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430 |
- // Package sample provides token sampling strategies for text generation
- package sample
- import (
- "container/heap"
- "math"
- "math/rand"
- "sort"
- )
// Config bundles the tunable parameters that control token sampling.
type Config struct {
	// Temperature scales the logits before softmax; 0 selects greedy decoding.
	Temperature float32
	// TopK keeps only the K highest-scoring tokens; 0 disables the filter.
	TopK int
	// TopP is the nucleus-sampling cumulative probability threshold;
	// 1.0 disables the filter.
	TopP float32
	// MinP drops tokens whose probability falls below MinP times the maximum
	// probability; 0 disables the filter.
	MinP float32
	// RepetitionPenalty is applied to the logits of recently seen tokens;
	// 1.0 disables it.
	RepetitionPenalty float32
	// Seed seeds the random generator; -1 requests a random seed.
	Seed int64
}
- // DefaultConfig returns sensible defaults
- func DefaultConfig() Config {
- return Config{
- Temperature: 0.7,
- TopK: 40,
- TopP: 0.9,
- MinP: 0.0,
- RepetitionPenalty: 1.1,
- Seed: -1,
- }
- }
- // Sampler handles token sampling with various strategies
- type Sampler struct {
- config Config
- rng *rand.Rand
- }
- // New creates a sampler with given config
- func New(cfg Config) *Sampler {
- var rng *rand.Rand
- if cfg.Seed >= 0 {
- rng = rand.New(rand.NewSource(cfg.Seed))
- } else {
- rng = rand.New(rand.NewSource(rand.Int63()))
- }
- return &Sampler{
- config: cfg,
- rng: rng,
- }
- }
- // Config returns a copy of the sampler configuration.
- func (s *Sampler) Config() Config {
- if s == nil {
- return Config{}
- }
- return s.config
- }
// candidate pairs a token id with its current score. The score starts as a
// logit and is turned into a probability by applySoftmax.
type candidate struct {
	id    int     // token id
	score float32 // logit or probability, depending on the pipeline stage
}
- type candidateHeap []candidate
- func (h candidateHeap) Len() int { return len(h) }
- func (h candidateHeap) Less(i, j int) bool {
- // min-heap by score
- return h[i].score < h[j].score
- }
- func (h candidateHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
- func (h *candidateHeap) Push(x interface{}) {
- *h = append(*h, x.(candidate))
- }
- func (h *candidateHeap) Pop() interface{} {
- old := *h
- n := len(old)
- x := old[n-1]
- *h = old[:n-1]
- return x
- }
- // Sample selects a token from logits using configured strategies
- // recentTokens is used for repetition penalty
- func (s *Sampler) Sample(logits []float32, recentTokens []int) int {
- if len(logits) == 0 {
- return 0
- }
- var seen map[int]struct{}
- if s.config.RepetitionPenalty != 1.0 && len(recentTokens) > 0 {
- seen = make(map[int]struct{}, len(recentTokens))
- for _, t := range recentTokens {
- seen[t] = struct{}{}
- }
- }
- // Temperature = 0 means greedy
- if s.config.Temperature == 0 {
- maxIdx := 0
- maxScore := logits[0]
- if seen != nil {
- if _, ok := seen[0]; ok {
- if maxScore > 0 {
- maxScore /= s.config.RepetitionPenalty
- } else {
- maxScore *= s.config.RepetitionPenalty
- }
- }
- }
- for i := 1; i < len(logits); i++ {
- score := logits[i]
- if seen != nil {
- if _, ok := seen[i]; ok {
- if score > 0 {
- score /= s.config.RepetitionPenalty
- } else {
- score *= s.config.RepetitionPenalty
- }
- }
- }
- if score > maxScore {
- maxScore = score
- maxIdx = i
- }
- }
- return maxIdx
- }
- // Fast path: if TopK is enabled, avoid allocating/sorting full vocab.
- if s.config.TopK > 0 && s.config.TopK < len(logits) {
- candidates := topKHeap(logits, s.config.TopK, seen, s.config.RepetitionPenalty)
- // Sort descending
- sort.Slice(candidates, func(i, j int) bool {
- return candidates[i].score > candidates[j].score
- })
- // Temperature scaling
- applyTemperature(candidates, s.config.Temperature)
- // Softmax
- applySoftmax(candidates)
- // Top-P (nucleus) filtering
- candidates = topP(candidates, s.config.TopP)
- // Min-P filtering
- candidates = minP(candidates, s.config.MinP)
- // Random sampling from remaining candidates
- return sampleFromProbs(candidates, s.rng)
- }
- // Fallback: full-vocab path (slower, but preserves behavior for TopK<=0)
- candidates := make([]candidate, len(logits))
- for i, l := range logits {
- candidates[i] = candidate{id: i, score: l}
- }
- if seen != nil {
- applyRepetitionPenalty(candidates, recentTokens, s.config.RepetitionPenalty)
- }
- // Top-K filtering and sorting
- candidates = topK(candidates, s.config.TopK)
- // Temperature scaling
- applyTemperature(candidates, s.config.Temperature)
- // Softmax
- applySoftmax(candidates)
- // Top-P (nucleus) filtering
- candidates = topP(candidates, s.config.TopP)
- // Min-P filtering
- candidates = minP(candidates, s.config.MinP)
- // Random sampling from remaining candidates
- return sampleFromProbs(candidates, s.rng)
- }
- func topKHeap(logits []float32, k int, seen map[int]struct{}, penalty float32) []candidate {
- if k <= 0 {
- k = len(logits)
- }
- if k > len(logits) {
- k = len(logits)
- }
- h := make(candidateHeap, 0, k)
- for i, l := range logits {
- score := l
- if seen != nil {
- if _, ok := seen[i]; ok {
- if score > 0 {
- score /= penalty
- } else {
- score *= penalty
- }
- }
- }
- c := candidate{id: i, score: score}
- if len(h) < k {
- heap.Push(&h, c)
- continue
- }
- if k > 0 && c.score > h[0].score {
- h[0] = c
- heap.Fix(&h, 0)
- }
- }
- // Return as a normal slice (unsorted)
- out := make([]candidate, len(h))
- copy(out, h)
- return out
- }
- // SampleFromTopK samples from an already top-k filtered candidate list.
- // The provided logits must already include repetition penalty if enabled.
- func (s *Sampler) SampleFromTopK(ids []int32, logits []float32) int {
- if len(ids) == 0 || len(logits) == 0 {
- return 0
- }
- if len(ids) != len(logits) {
- // best-effort: clamp to shorter
- n := len(ids)
- if len(logits) < n {
- n = len(logits)
- }
- ids = ids[:n]
- logits = logits[:n]
- }
- // Greedy path
- if s.config.Temperature == 0 {
- maxIdx := 0
- maxScore := logits[0]
- for i := 1; i < len(logits); i++ {
- if logits[i] > maxScore {
- maxScore = logits[i]
- maxIdx = i
- }
- }
- return int(ids[maxIdx])
- }
- candidates := make([]candidate, len(ids))
- for i := range ids {
- candidates[i] = candidate{id: int(ids[i]), score: logits[i]}
- }
- // Sort descending
- sort.Slice(candidates, func(i, j int) bool {
- return candidates[i].score > candidates[j].score
- })
- applyTemperature(candidates, s.config.Temperature)
- applySoftmax(candidates)
- candidates = topP(candidates, s.config.TopP)
- candidates = minP(candidates, s.config.MinP)
- return sampleFromProbs(candidates, s.rng)
- }
// SampleGreedy returns the index of the largest logit, i.e. the most probable
// token. Ties resolve to the lowest index. Empty logits yield 0, the same
// result Sample and SampleFromTopK give for empty input, instead of panicking.
func SampleGreedy(logits []float32) int {
	if len(logits) == 0 {
		return 0
	}
	best := 0
	for i := 1; i < len(logits); i++ {
		if logits[i] > logits[best] {
			best = i
		}
	}
	return best
}
- // greedy returns id of highest scoring candidate
- func greedy(candidates []candidate) int {
- maxIdx := 0
- maxScore := candidates[0].score
- for i := 1; i < len(candidates); i++ {
- if candidates[i].score > maxScore {
- maxScore = candidates[i].score
- maxIdx = i
- }
- }
- return candidates[maxIdx].id
- }
- // topK keeps only the k highest scoring candidates (sorted descending)
- func topK(candidates []candidate, k int) []candidate {
- if k <= 0 || k >= len(candidates) {
- // Sort all descending
- sort.Slice(candidates, func(i, j int) bool {
- return candidates[i].score > candidates[j].score
- })
- return candidates
- }
- // Sort all descending and take top k
- sort.Slice(candidates, func(i, j int) bool {
- return candidates[i].score > candidates[j].score
- })
- return candidates[:k]
- }
- // applyTemperature scales logits by temperature
- func applyTemperature(candidates []candidate, temp float32) {
- if temp < 1e-7 {
- temp = 1e-7 // Avoid division by zero
- }
- for i := range candidates {
- candidates[i].score /= temp
- }
- }
- // applySoftmax converts logits to probabilities
- func applySoftmax(candidates []candidate) {
- // Find max for numerical stability
- maxScore := candidates[0].score
- for _, c := range candidates[1:] {
- if c.score > maxScore {
- maxScore = c.score
- }
- }
- // Compute exp(x - max)
- var sum float32
- for i := range candidates {
- candidates[i].score = float32(math.Exp(float64(candidates[i].score - maxScore)))
- sum += candidates[i].score
- }
- // Normalize
- for i := range candidates {
- candidates[i].score /= sum
- }
- }
- // topP keeps tokens until cumulative probability exceeds p (nucleus sampling)
- func topP(candidates []candidate, p float32) []candidate {
- if p >= 1.0 {
- return candidates
- }
- var cumSum float32
- for i, c := range candidates {
- cumSum += c.score
- if cumSum > p {
- if i == 0 {
- return candidates[:1]
- }
- return candidates[:i+1]
- }
- }
- return candidates
- }
- // minP keeps tokens with probability >= p * max_probability
- func minP(candidates []candidate, p float32) []candidate {
- if p <= 0 || len(candidates) == 0 {
- return candidates
- }
- maxProb := candidates[0].score // Assumes sorted descending
- threshold := maxProb * p
- for i, c := range candidates {
- if c.score < threshold {
- if i == 0 {
- return candidates[:1]
- }
- return candidates[:i]
- }
- }
- return candidates
- }
- // applyRepetitionPenalty penalizes recently used tokens
- func applyRepetitionPenalty(candidates []candidate, recentTokens []int, penalty float32) {
- seen := make(map[int]bool)
- for _, t := range recentTokens {
- seen[t] = true
- }
- for i := range candidates {
- if seen[candidates[i].id] {
- if candidates[i].score > 0 {
- candidates[i].score /= penalty
- } else {
- candidates[i].score *= penalty
- }
- }
- }
- }
- // sampleFromProbs randomly selects a token based on probability distribution
- func sampleFromProbs(candidates []candidate, rng *rand.Rand) int {
- if len(candidates) == 0 {
- return 0
- }
- if len(candidates) == 1 {
- return candidates[0].id
- }
- // Compute cumulative sum
- cumSum := make([]float32, len(candidates))
- cumSum[0] = candidates[0].score
- for i := 1; i < len(candidates); i++ {
- cumSum[i] = cumSum[i-1] + candidates[i].score
- }
- // Random sample
- r := rng.Float32() * cumSum[len(cumSum)-1]
- // Binary search
- for i, cs := range cumSum {
- if r <= cs {
- return candidates[i].id
- }
- }
- return candidates[len(candidates)-1].id
- }
|