tokenizer.go
// Package tokenizer provides BPE tokenization for LLM models
package tokenizer

import (
	"regexp"
	"sort"
	"strings"
)

// Tokenizer handles text tokenization using byte-level BPE
type Tokenizer struct {
	vocab       map[string]int // token -> id
	idToToken   map[int]string // id -> token
	merges      map[string]int // "a b" -> rank
	byteEncoder map[byte]rune  // byte -> unicode
	byteDecoder map[rune]byte  // unicode -> byte
	addedTokens map[string]int // special tokens
	eosID       int            // end of sequence token
	prePattern  *regexp.Regexp // pre-tokenizer regex
}
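
// The byteEncoder/byteDecoder tables above are typically built with the
// GPT-2 "bytes to unicode" scheme. The helper below is an illustrative
// sketch of that scheme, not code from this file: printable, non-whitespace
// bytes map to themselves and every other byte is shifted into an unused
// code-point range so it has a printable stand-in.
func bytesToUnicodeSketch() map[byte]rune {
	enc := make(map[byte]rune, 256)
	shift := 0
	for b := 0; b < 256; b++ {
		printable := (b >= '!' && b <= '~') || (b >= 0xA1 && b <= 0xAC) || (b >= 0xAE && b <= 0xFF)
		if printable {
			enc[byte(b)] = rune(b)
		} else {
			enc[byte(b)] = rune(256 + shift)
			shift++
		}
	}
	return enc
}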

// Encode converts text to token IDs
func (t *Tokenizer) Encode(text string) []int {
	var ids []int
	// Split on special tokens first
	segments := t.splitOnSpecialTokens(text)
	// Process each segment
	for _, seg := range segments {
		if seg.isSpecial {
			ids = append(ids, seg.id)
		} else {
			ids = append(ids, t.encodeText(seg.text)...)
		}
	}
	return ids
}

// encodeText tokenizes regular text (non-special tokens)
func (t *Tokenizer) encodeText(text string) []int {
	var ids []int
	// Pre-tokenization
	var chunks []string
	if t.prePattern != nil {
		chunks = t.prePattern.FindAllString(text, -1)
	} else {
		chunks = []string{text}
	}
	// BPE for each chunk
	for _, chunk := range chunks {
		byteRep := t.bytesToTokens([]byte(chunk))
		tokens := t.bpe(byteRep)
		for _, tok := range tokens {
			if id, ok := t.vocab[tok]; ok {
				ids = append(ids, id)
			}
		}
	}
	return ids
}
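
// The bpe method called above is not shown in this excerpt. As a rough
// illustration only (an assumption about the standard greedy algorithm,
// not this package's actual implementation), a byte-level BPE loop
// repeatedly merges the adjacent pair with the lowest rank in the merges
// table until no ranked pair remains:
func bpeSketch(tokens []string, merges map[string]int) []string {
	for len(tokens) > 1 {
		bestIdx, bestRank := -1, int(^uint(0)>>1) // max int
		for i := 0; i < len(tokens)-1; i++ {
			if r, ok := merges[tokens[i]+" "+tokens[i+1]]; ok && r < bestRank {
				bestIdx, bestRank = i, r
			}
		}
		if bestIdx < 0 {
			break // no known merge left
		}
		tokens[bestIdx] = tokens[bestIdx] + tokens[bestIdx+1]
		tokens = append(tokens[:bestIdx+1], tokens[bestIdx+2:]...)
	}
	return tokens
}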

// Decode converts token IDs back to text
func (t *Tokenizer) Decode(ids []int) string {
	var tokens []string
	for _, id := range ids {
		if tok, ok := t.idToToken[id]; ok {
			tokens = append(tokens, tok)
		}
	}
	text := strings.Join(tokens, "")
	return string(t.tokensToBytes(text))
}

// EosID returns the end-of-sequence token ID
func (t *Tokenizer) EosID() int {
	return t.eosID
}

// VocabSize returns the vocabulary size
func (t *Tokenizer) VocabSize() int {
	return len(t.vocab) + len(t.addedTokens)
}

// GetToken returns the token string for a given ID
func (t *Tokenizer) GetToken(id int) (string, bool) {
	tok, ok := t.idToToken[id]
	return tok, ok
}

// AddedTokenStrings returns the special (added) token strings in sorted order
func (t *Tokenizer) AddedTokenStrings() []string {
	out := make([]string, 0, len(t.addedTokens))
	for s := range t.addedTokens {
		out = append(out, s)
	}
	sort.Strings(out)
	return out
}
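
// Example usage (the constructor name below is an assumption; this file
// does not show how a Tokenizer is loaded):
//
//	tok, err := tokenizer.NewFromFile("tokenizer.json")
//	if err != nil {
//		log.Fatal(err)
//	}
//	ids := tok.Encode("Hello, world!")
//	fmt.Println(ids)             // token IDs
//	fmt.Println(tok.Decode(ids)) // should round-trip to "Hello, world!"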