| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990 |
- package weightmap
- import (
- "fmt"
- "strings"
- "sync"
- )
// WeightMapper translates a HuggingFace checkpoint weight name into the
// internal model layout.
//
// For a given weight name, MapName reports:
//   - layer: the index of the layer the weight belongs to, for layer-scoped weights;
//   - field: the internal field name the weight maps to;
//   - matched: whether this mapper recognized the name at all.
type WeightMapper interface {
	MapName(weightName string) (layer int, field string, matched bool)
}
// MapperFunc adapts an ordinary function with the MapName signature so it
// can be used wherever a WeightMapper is expected.
type MapperFunc func(name string) (int, string, bool)

// MapName calls fn with name and returns its results unchanged.
func (fn MapperFunc) MapName(name string) (int, string, bool) {
	return fn(name)
}
- var (
- mu sync.RWMutex
- mappers = map[string]WeightMapper{}
- fallback WeightMapper
- )
- // Register registers a mapper for an architecture key.
- // The key should be stable (e.g. "qwen3").
- func Register(key string, mapper WeightMapper) {
- key = strings.ToLower(key)
- mu.Lock()
- defer mu.Unlock()
- mappers[key] = mapper
- if fallback == nil {
- fallback = mapper
- }
- }
- // Get returns a mapper by key. If missing, returns the first registered mapper (if any).
- func Get(key string) (WeightMapper, bool) {
- key = strings.ToLower(key)
- mu.RLock()
- defer mu.RUnlock()
- m, ok := mappers[key]
- if ok {
- return m, true
- }
- return fallback, false
- }
- // GetForArchitecture resolves a mapper for a raw architecture string.
- // It first tries exact match, then substring match over registered keys.
- func GetForArchitecture(arch string) WeightMapper {
- archLower := strings.ToLower(arch)
- if m, ok := Get(archLower); ok {
- return m
- }
- mu.RLock()
- defer mu.RUnlock()
- for k, m := range mappers {
- if strings.Contains(archLower, k) {
- return m
- }
- }
- return fallback
- }
// ParseLayerName splits a HuggingFace weight name of the form
// prefix + "<layer>.<suffix>" into a layer index and a field name.
//
// Parsing rules:
//   - name must begin with prefix.
//   - The remainder must be a non-empty run of ASCII decimal digits, then '.',
//     then a non-empty suffix.
//   - The suffix is everything after that first '.', so dotted suffixes such as
//     "self_attn.q_proj.weight" are kept intact.
//
// If fields contains the suffix, the mapped internal name is returned;
// otherwise the raw suffix is returned. A nil fields map behaves as empty.
//
// On any mismatch it returns (-1, "", false). Signed ("-1", "+2") or
// whitespace-padded indices are rejected rather than parsed, so a successful
// result always has a non-negative layer index. That index never collides
// with the -1 failure value.
func ParseLayerName(name, prefix string, fields map[string]string) (int, string, bool) {
	if !strings.HasPrefix(name, prefix) {
		return -1, "", false
	}
	rest := name[len(prefix):]

	dot := strings.IndexByte(rest, '.')
	if dot <= 0 || dot == len(rest)-1 {
		// No dot, nothing before the dot, or nothing after it.
		return -1, "", false
	}
	idx, suffix := rest[:dot], rest[dot+1:]
	for i := 0; i < len(idx); i++ {
		if idx[i] < '0' || idx[i] > '9' {
			return -1, "", false
		}
	}

	var layerIdx int
	// idx is pure decimal digits at this point. "%d" parses it in base 10 and
	// reports an error if the value does not fit in an int.
	if _, err := fmt.Sscanf(idx, "%d", &layerIdx); err != nil {
		return -1, "", false
	}

	if field, ok := fields[suffix]; ok {
		return layerIdx, field, true
	}
	return layerIdx, suffix, true
}
|