registry.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. package weightmap
  2. import (
  3. "fmt"
  4. "strings"
  5. "sync"
  6. )
  7. // WeightMapper maps HuggingFace weight names to internal structure.
  8. // Returns:
  9. // - layerIdx: layer index for layer-scoped weights
  10. // - fieldName: internal field name
  11. // - ok: whether mapping matched
  12. type WeightMapper interface {
  13. MapName(name string) (layerIdx int, fieldName string, ok bool)
  14. }
  15. // MapperFunc is a helper to turn a function into a WeightMapper.
  16. type MapperFunc func(name string) (int, string, bool)
  17. func (f MapperFunc) MapName(name string) (int, string, bool) { return f(name) }
  18. var (
  19. mu sync.RWMutex
  20. mappers = map[string]WeightMapper{}
  21. fallback WeightMapper
  22. )
  23. // Register registers a mapper for an architecture key.
  24. // The key should be stable (e.g. "qwen3").
  25. func Register(key string, mapper WeightMapper) {
  26. key = strings.ToLower(key)
  27. mu.Lock()
  28. defer mu.Unlock()
  29. mappers[key] = mapper
  30. if fallback == nil {
  31. fallback = mapper
  32. }
  33. }
  34. // Get returns a mapper by key. If missing, returns the first registered mapper (if any).
  35. func Get(key string) (WeightMapper, bool) {
  36. key = strings.ToLower(key)
  37. mu.RLock()
  38. defer mu.RUnlock()
  39. m, ok := mappers[key]
  40. if ok {
  41. return m, true
  42. }
  43. return fallback, false
  44. }
  45. // GetForArchitecture resolves a mapper for a raw architecture string.
  46. // It first tries exact match, then substring match over registered keys.
  47. func GetForArchitecture(arch string) WeightMapper {
  48. archLower := strings.ToLower(arch)
  49. if m, ok := Get(archLower); ok {
  50. return m
  51. }
  52. mu.RLock()
  53. defer mu.RUnlock()
  54. for k, m := range mappers {
  55. if strings.Contains(archLower, k) {
  56. return m
  57. }
  58. }
  59. return fallback
  60. }
  61. // ParseLayerName extracts layer index and field name from HF weight name.
  62. func ParseLayerName(name, prefix string, fields map[string]string) (int, string, bool) {
  63. if !strings.HasPrefix(name, prefix) {
  64. return -1, "", false
  65. }
  66. rest := name[len(prefix):]
  67. var layerIdx int
  68. var suffix string
  69. if _, err := fmt.Sscanf(rest, "%d.%s", &layerIdx, &suffix); err != nil {
  70. return -1, "", false
  71. }
  72. if field, ok := fields[suffix]; ok {
  73. return layerIdx, field, true
  74. }
  75. return layerIdx, suffix, true
  76. }