// Package sample provides token sampling strategies for text generation package sample import ( "container/heap" "math" "math/rand" "sort" ) // Config holds sampling parameters type Config struct { Temperature float32 // Temperature for logit scaling (0 = greedy) TopK int // Keep only top K tokens (0 = disabled) TopP float32 // Nucleus sampling threshold (1.0 = disabled) MinP float32 // Minimum probability relative to max (0 = disabled) RepetitionPenalty float32 // Penalty for repeated tokens (1.0 = disabled) Seed int64 // Random seed (-1 = random) } // DefaultConfig returns sensible defaults func DefaultConfig() Config { return Config{ Temperature: 0.7, TopK: 40, TopP: 0.9, MinP: 0.0, RepetitionPenalty: 1.1, Seed: -1, } } // Sampler handles token sampling with various strategies type Sampler struct { config Config rng *rand.Rand } // New creates a sampler with given config func New(cfg Config) *Sampler { var rng *rand.Rand if cfg.Seed >= 0 { rng = rand.New(rand.NewSource(cfg.Seed)) } else { rng = rand.New(rand.NewSource(rand.Int63())) } return &Sampler{ config: cfg, rng: rng, } } // Config returns a copy of the sampler configuration. func (s *Sampler) Config() Config { if s == nil { return Config{} } return s.config } // candidate holds token id and its score type candidate struct { id int score float32 } type candidateHeap []candidate func (h candidateHeap) Len() int { return len(h) } func (h candidateHeap) Less(i, j int) bool { // min-heap by score return h[i].score < h[j].score } func (h candidateHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } func (h *candidateHeap) Push(x interface{}) { *h = append(*h, x.(candidate)) } func (h *candidateHeap) Pop() interface{} { old := *h n := len(old) x := old[n-1] *h = old[:n-1] return x } // Sample selects a token from logits using configured strategies // recentTokens is used for repetition penalty func (s *Sampler) Sample(logits []float32, recentTokens []int) int { if len(logits) == 0 { return 0 } var seen map[int]struct{} if s.config.RepetitionPenalty != 1.0 && len(recentTokens) > 0 { seen = make(map[int]struct{}, len(recentTokens)) for _, t := range recentTokens { seen[t] = struct{}{} } } // Temperature = 0 means greedy if s.config.Temperature == 0 { maxIdx := 0 maxScore := logits[0] if seen != nil { if _, ok := seen[0]; ok { if maxScore > 0 { maxScore /= s.config.RepetitionPenalty } else { maxScore *= s.config.RepetitionPenalty } } } for i := 1; i < len(logits); i++ { score := logits[i] if seen != nil { if _, ok := seen[i]; ok { if score > 0 { score /= s.config.RepetitionPenalty } else { score *= s.config.RepetitionPenalty } } } if score > maxScore { maxScore = score maxIdx = i } } return maxIdx } // Fast path: if TopK is enabled, avoid allocating/sorting full vocab. if s.config.TopK > 0 && s.config.TopK < len(logits) { candidates := topKHeap(logits, s.config.TopK, seen, s.config.RepetitionPenalty) // Sort descending sort.Slice(candidates, func(i, j int) bool { return candidates[i].score > candidates[j].score }) // Temperature scaling applyTemperature(candidates, s.config.Temperature) // Softmax applySoftmax(candidates) // Top-P (nucleus) filtering candidates = topP(candidates, s.config.TopP) // Min-P filtering candidates = minP(candidates, s.config.MinP) // Random sampling from remaining candidates return sampleFromProbs(candidates, s.rng) } // Fallback: full-vocab path (slower, but preserves behavior for TopK<=0) candidates := make([]candidate, len(logits)) for i, l := range logits { candidates[i] = candidate{id: i, score: l} } if seen != nil { applyRepetitionPenalty(candidates, recentTokens, s.config.RepetitionPenalty) } // Top-K filtering and sorting candidates = topK(candidates, s.config.TopK) // Temperature scaling applyTemperature(candidates, s.config.Temperature) // Softmax applySoftmax(candidates) // Top-P (nucleus) filtering candidates = topP(candidates, s.config.TopP) // Min-P filtering candidates = minP(candidates, s.config.MinP) // Random sampling from remaining candidates return sampleFromProbs(candidates, s.rng) } func topKHeap(logits []float32, k int, seen map[int]struct{}, penalty float32) []candidate { if k <= 0 { k = len(logits) } if k > len(logits) { k = len(logits) } h := make(candidateHeap, 0, k) for i, l := range logits { score := l if seen != nil { if _, ok := seen[i]; ok { if score > 0 { score /= penalty } else { score *= penalty } } } c := candidate{id: i, score: score} if len(h) < k { heap.Push(&h, c) continue } if k > 0 && c.score > h[0].score { h[0] = c heap.Fix(&h, 0) } } // Return as a normal slice (unsorted) out := make([]candidate, len(h)) copy(out, h) return out } // SampleFromTopK samples from an already top-k filtered candidate list. // The provided logits must already include repetition penalty if enabled. func (s *Sampler) SampleFromTopK(ids []int32, logits []float32) int { if len(ids) == 0 || len(logits) == 0 { return 0 } if len(ids) != len(logits) { // best-effort: clamp to shorter n := len(ids) if len(logits) < n { n = len(logits) } ids = ids[:n] logits = logits[:n] } // Greedy path if s.config.Temperature == 0 { maxIdx := 0 maxScore := logits[0] for i := 1; i < len(logits); i++ { if logits[i] > maxScore { maxScore = logits[i] maxIdx = i } } return int(ids[maxIdx]) } candidates := make([]candidate, len(ids)) for i := range ids { candidates[i] = candidate{id: int(ids[i]), score: logits[i]} } // Sort descending sort.Slice(candidates, func(i, j int) bool { return candidates[i].score > candidates[j].score }) applyTemperature(candidates, s.config.Temperature) applySoftmax(candidates) candidates = topP(candidates, s.config.TopP) candidates = minP(candidates, s.config.MinP) return sampleFromProbs(candidates, s.rng) } // SampleGreedy returns the highest probability token func SampleGreedy(logits []float32) int { maxIdx := 0 maxVal := logits[0] for i := 1; i < len(logits); i++ { if logits[i] > maxVal { maxVal = logits[i] maxIdx = i } } return maxIdx } // greedy returns id of highest scoring candidate func greedy(candidates []candidate) int { maxIdx := 0 maxScore := candidates[0].score for i := 1; i < len(candidates); i++ { if candidates[i].score > maxScore { maxScore = candidates[i].score maxIdx = i } } return candidates[maxIdx].id } // topK keeps only the k highest scoring candidates (sorted descending) func topK(candidates []candidate, k int) []candidate { if k <= 0 || k >= len(candidates) { // Sort all descending sort.Slice(candidates, func(i, j int) bool { return candidates[i].score > candidates[j].score }) return candidates } // Sort all descending and take top k sort.Slice(candidates, func(i, j int) bool { return candidates[i].score > candidates[j].score }) return candidates[:k] } // applyTemperature scales logits by temperature func applyTemperature(candidates []candidate, temp float32) { if temp < 1e-7 { temp = 1e-7 // Avoid division by zero } for i := range candidates { candidates[i].score /= temp } } // applySoftmax converts logits to probabilities func applySoftmax(candidates []candidate) { // Find max for numerical stability maxScore := candidates[0].score for _, c := range candidates[1:] { if c.score > maxScore { maxScore = c.score } } // Compute exp(x - max) var sum float32 for i := range candidates { candidates[i].score = float32(math.Exp(float64(candidates[i].score - maxScore))) sum += candidates[i].score } // Normalize for i := range candidates { candidates[i].score /= sum } } // topP keeps tokens until cumulative probability exceeds p (nucleus sampling) func topP(candidates []candidate, p float32) []candidate { if p >= 1.0 { return candidates } var cumSum float32 for i, c := range candidates { cumSum += c.score if cumSum > p { if i == 0 { return candidates[:1] } return candidates[:i+1] } } return candidates } // minP keeps tokens with probability >= p * max_probability func minP(candidates []candidate, p float32) []candidate { if p <= 0 || len(candidates) == 0 { return candidates } maxProb := candidates[0].score // Assumes sorted descending threshold := maxProb * p for i, c := range candidates { if c.score < threshold { if i == 0 { return candidates[:1] } return candidates[:i] } } return candidates } // applyRepetitionPenalty penalizes recently used tokens func applyRepetitionPenalty(candidates []candidate, recentTokens []int, penalty float32) { seen := make(map[int]bool) for _, t := range recentTokens { seen[t] = true } for i := range candidates { if seen[candidates[i].id] { if candidates[i].score > 0 { candidates[i].score /= penalty } else { candidates[i].score *= penalty } } } } // sampleFromProbs randomly selects a token based on probability distribution func sampleFromProbs(candidates []candidate, rng *rand.Rand) int { if len(candidates) == 0 { return 0 } if len(candidates) == 1 { return candidates[0].id } // Compute cumulative sum cumSum := make([]float32, len(candidates)) cumSum[0] = candidates[0].score for i := 1; i < len(candidates); i++ { cumSum[i] = cumSum[i-1] + candidates[i].score } // Random sample r := rng.Float32() * cumSum[len(cumSum)-1] // Binary search for i, cs := range cumSum { if r <= cs { return candidates[i].id } } return candidates[len(candidates)-1].id }