loader.go

package tokenizer

import (
	"encoding/json"
	"os"
	"strings"
)
// JSON structures for the HuggingFace tokenizer.json format.
type addedTokenJSON struct {
	ID      int    `json:"id"`
	Content string `json:"content"`
	Special bool   `json:"special"`
}

type modelJSON struct {
	Type   string          `json:"type"`
	Vocab  map[string]int  `json:"vocab"`
	Merges json.RawMessage `json:"merges"` // can be []string or [][]string
}

type tokenizerJSON struct {
	AddedTokens []addedTokenJSON `json:"added_tokens"`
	Model       modelJSON        `json:"model"`
}
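
// For reference, a minimal tokenizer.json that these structs decode looks
// like the constant below. The vocabulary, merges, and token ids are
// illustrative placeholders, not taken from a real model.
const exampleTokenizerJSON = `{
  "added_tokens": [{"id": 3, "content": "<|endoftext|>", "special": true}],
  "model": {
    "type": "BPE",
    "vocab": {"a": 0, "b": 1, "ab": 2},
    "merges": ["a b"]
  }
}`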
// LoadFromJSON loads a tokenizer from a file in HuggingFace tokenizer.json format.
func LoadFromJSON(path string) (*Tokenizer, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}
	return LoadFromBytes(data)
}
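
// A typical call site, sketched (the path is a placeholder; it would point
// at a downloaded HuggingFace model directory):
//
//	t, err := tokenizer.LoadFromJSON("model/tokenizer.json")
//	if err != nil {
//		// handle error
//	}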
// LoadFromBytes loads a tokenizer from raw tokenizer.json data.
func LoadFromBytes(data []byte) (*Tokenizer, error) {
	var tJSON tokenizerJSON
	if err := json.Unmarshal(data, &tJSON); err != nil {
		return nil, err
	}

	// Parse merges; the field can be encoded as []string or [][]string.
	merges := parseMerges(tJSON.Model.Merges)

	byteEnc, byteDec := buildByteEncoder()
	t := &Tokenizer{
		vocab:       tJSON.Model.Vocab,
		idToToken:   make(map[int]string),
		merges:      make(map[string]int),
		addedTokens: make(map[string]int),
		byteEncoder: byteEnc,
		byteDecoder: byteDec,
	}

	// Build the reverse id -> token map.
	for k, v := range t.vocab {
		t.idToToken[v] = k
	}

	// Record merge priorities: each rule is an "a b" pair, and a lower
	// index means higher priority.
	for i, m := range merges {
		t.merges[m] = i
	}

	// Register added (special) tokens, and pick the EOS id heuristically
	// from common names such as <|endoftext|> and <|im_end|>.
	for _, at := range tJSON.AddedTokens {
		t.addedTokens[at.Content] = at.ID
		t.idToToken[at.ID] = at.Content
		if strings.Contains(at.Content, "endoftext") || strings.Contains(at.Content, "im_end") {
			t.eosID = at.ID
		}
	}

	// Compile the pre-tokenizer pattern based on the tokenizer model type.
	t.prePattern = CompilePattern(DetectPattern(tJSON.Model.Type))
	return t, nil
}
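
// loadExample is a minimal sketch, not part of the loader API, showing
// LoadFromBytes on the illustrative JSON constant above: the resulting
// tokenizer maps id 2 back to "ab" and records id 3 as the EOS token.
func loadExample() (*Tokenizer, error) {
	return LoadFromBytes([]byte(exampleTokenizerJSON))
}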
// parseMerges handles both the []string and [][]string merge formats.
func parseMerges(raw json.RawMessage) []string {
	if len(raw) == 0 {
		return nil
	}

	// Try []string first.
	var merges []string
	if err := json.Unmarshal(raw, &merges); err == nil {
		return merges
	}

	// Fall back to [][]string, joining each pair into "a b" form.
	var mergePairs [][]string
	if err := json.Unmarshal(raw, &mergePairs); err == nil {
		merges = make([]string, 0, len(mergePairs))
		for _, pair := range mergePairs {
			if len(pair) == 2 {
				merges = append(merges, pair[0]+" "+pair[1])
			}
		}
		return merges
	}
	return nil
}
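
// Both encodings normalize to the same flat form; for illustration
// (hypothetical inputs, shown here rather than in a _test.go file):
//
//	parseMerges([]byte(`["a b", "ab c"]`))           // -> []string{"a b", "ab c"}
//	parseMerges([]byte(`[["a", "b"], ["ab", "c"]]`)) // -> []string{"a b", "ab c"}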