pretokenizer.go 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. package tokenizer
  2. import "regexp"
  // Pre-tokenizer split patterns, one per supported model family.
  //
  // NOTE(review): every pattern below contains the negative lookahead `(?!\S)`.
  // Go's standard regexp package (RE2 syntax) does not accept lookaheads, so
  // CompilePattern returns nil for all of them as written. Presumably they are
  // meant for a lookahead-capable engine — confirm before relying on them.
  const (
  	// QwenPattern is the pre-tokenizer regex for Qwen/Qwen2/Qwen3 models.
  	// Numbers are split one code point at a time (`\p{N}`).
  	QwenPattern = `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`

  	// KimiPattern is the pre-tokenizer regex returned for model types that
  	// contain "kimi". Runs of Han characters are matched by their own leading
  	// alternative, and numbers are split into groups of one to three digits.
  	// NOTE(review): it also uses `&&` inside character classes, which looks
  	// like class-intersection syntax from another regex dialect — confirm the
  	// target engine supports it.
  	KimiPattern = `[\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`

  	// LlamaPattern is the pre-tokenizer regex for Llama models. It differs
  	// from QwenPattern only in grouping digits in runs of one to three.
  	LlamaPattern = `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`

  	// GPT2Pattern is the pre-tokenizer regex for GPT-2 style models. Its
  	// contraction alternatives are case-sensitive, unlike the other patterns.
  	GPT2Pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
  )
  // CompilePattern compiles a pre-tokenizer pattern with the standard regexp
  // package.
  //
  // It returns the compiled expression, or nil when pattern is not valid for
  // that engine. The compile error itself is discarded, so callers must check
  // the result for nil before using it. Patterns that use a lookahead such as
  // `(?!\S)` are rejected by Go's regexp syntax and therefore yield nil.
  func CompilePattern(pattern string) *regexp.Regexp {
  	compiled, err := regexp.Compile(pattern)
  	if err != nil {
  		// The signature has no error return; failure is reported as nil.
  		return nil
  	}
  	return compiled
  }
  16. // DetectPattern returns the appropriate pattern based on model type
  17. func DetectPattern(modelType string) string {
  18. switch {
  19. case contains(modelType, "kimi"):
  20. return KimiPattern
  21. case contains(modelType, "qwen"):
  22. return QwenPattern
  23. case contains(modelType, "llama"):
  24. return LlamaPattern
  25. case contains(modelType, "gpt"):
  26. return GPT2Pattern
  27. default:
  28. return QwenPattern // Default to Qwen pattern
  29. }
  30. }
  31. func contains(s, substr string) bool {
  32. return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsLower(s, substr))
  33. }
  34. func containsLower(s, substr string) bool {
  35. for i := 0; i <= len(s)-len(substr); i++ {
  36. if eqFoldAt(s, i, substr) {
  37. return true
  38. }
  39. }
  40. return false
  41. }
  42. func eqFoldAt(s string, i int, substr string) bool {
  43. for j := 0; j < len(substr); j++ {
  44. c1, c2 := s[i+j], substr[j]
  45. if c1 != c2 && toLower(c1) != toLower(c2) {
  46. return false
  47. }
  48. }
  49. return true
  50. }
  // toLower maps the ASCII uppercase letters 'A'-'Z' to 'a'-'z' and returns
  // every other byte unchanged.
  func toLower(c byte) byte {
  	if c < 'A' || c > 'Z' {
  		return c
  	}
  	return c + ('a' - 'A')
  }