convert.go

package convert

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"os"
	"path/filepath"
	"runtime"
	"sort"
	"strings"
	"sync"

	"makarna/pkg/loader"
	"makarna/pkg/quant"
)

// Options controls how ConvertDirectory converts and quantizes a model.
type Options struct {
	// BaseQuant is the default quantization type applied to quantizable tensors.
	BaseQuant quant.QuantType
	MixMode   bool
	// Workers is the number of tensors converted concurrently. <= 0 selects GOMAXPROCS.
	Workers int
	// MaxInFlightBytes bounds peak RAM by limiting concurrently processed tensors
	// (decoded float32 + output buffer). 0 = auto.
	MaxInFlightBytes uint64
}

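// ConvertDirectory reads a safetensors model directory (config.json, tokenizer,
// one or more *.safetensors files) from modelPath and writes a single converted
// model file to outputPath, quantizing eligible tensors according to opts.
//
// A minimal call might look like this (paths and option values are illustrative):
//
//	err := convert.ConvertDirectory("/models/my-model", "/models/my-model.bin", convert.Options{
//		BaseQuant: quant.TypeQ4K,
//		Workers:   8,
//	})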
func ConvertDirectory(modelPath string, outputPath string, opts Options) error {
	config, architecture, tieWordEmbeddings, err := readConfig(modelPath)
	if err != nil {
		return err
	}
	spec := NewSpec(architecture, tieWordEmbeddings, opts.BaseQuant, opts.MixMode)
	sfFiles := findSafetensorsFiles(modelPath)
	if len(sfFiles) == 0 {
		return fmt.Errorf("no safetensors files found in %s", modelPath)
	}
	writer, err := loader.NewWriter(outputPath)
	if err != nil {
		return fmt.Errorf("create output: %w", err)
	}
	// Remove the partially written output unless conversion completes successfully.
	ok := false
	defer func() {
		if ok {
			return
		}
		_ = writer.Close()
		_ = os.Remove(outputPath)
	}()
	writer.SetModelConfig(loader.ModelConfig{Architecture: architecture, Params: config})
	tokenizerPath := filepath.Join(modelPath, "tokenizer.json")
	if tokData, err := os.ReadFile(tokenizerPath); err == nil {
		writer.AddTokenizer(tokData)
	} else {
		// Fall back to the alternate tokenizer location used by some checkpoints.
		altTokenizerPath := filepath.Join(modelPath, "gpt2_tokenizer", "tokenizer.json")
		if tokData, err := os.ReadFile(altTokenizerPath); err == nil {
			writer.AddTokenizer(tokData)
		}
	}
	workers := opts.Workers
	if workers <= 0 {
		workers = runtime.GOMAXPROCS(0)
		if workers <= 0 {
			workers = 1
		}
	}
	maxInFlight := opts.MaxInFlightBytes
	if maxInFlight == 0 {
		// Conservative default to keep conversion usable on machines with limited RAM.
		maxInFlight = 4 << 30 // 4 GiB
	}
	limiter := newMemLimiter(maxInFlight)
	for _, sfPath := range sfFiles {
		if err := convertSafetensorsFile(sfPath, writer, spec, workers, limiter); err != nil {
			return err
		}
	}
	if err := writer.Close(); err != nil {
		return err
	}
	ok = true
	return nil
}

// stTensor is a fully materialized safetensors tensor backed by a slice of the file buffer.
type stTensor struct {
	DType string
	Shape []uint64
	Data  []byte
}

// stHeaderTensor mirrors one entry of the safetensors JSON header.
type stHeaderTensor struct {
	DType       string   `json:"dtype"`
	Shape       []uint64 `json:"shape"`
	DataOffsets []uint64 `json:"data_offsets"`
}

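// readSafetensorsRaw parses an in-memory safetensors buffer: an 8-byte
// little-endian header length, a JSON header describing each tensor, then the
// raw tensor data. Returned tensors alias buf rather than copying it.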
func readSafetensorsRaw(buf []byte) (map[string]stTensor, error) {
	if len(buf) < 8 {
		return nil, fmt.Errorf("invalid safetensors: too short")
	}
	headerLen := binary.LittleEndian.Uint64(buf[0:8])
	if headerLen > uint64(len(buf)-8) {
		return nil, fmt.Errorf("invalid safetensors: header too large")
	}
	headerStart := 8
	headerEnd := 8 + int(headerLen)
	dataStart := headerEnd
	var header map[string]stHeaderTensor
	if err := json.Unmarshal(buf[headerStart:headerEnd], &header); err != nil {
		return nil, fmt.Errorf("invalid safetensors header json: %w", err)
	}
	out := make(map[string]stTensor, len(header))
	for name, ti := range header {
		if name == "__metadata__" {
			continue
		}
		if len(ti.DataOffsets) != 2 {
			return nil, fmt.Errorf("tensor %s: invalid data_offsets", name)
		}
		o0 := ti.DataOffsets[0]
		o1 := ti.DataOffsets[1]
		if o1 < o0 {
			return nil, fmt.Errorf("tensor %s: invalid data_offsets range", name)
		}
		abs0 := dataStart + int(o0)
		abs1 := dataStart + int(o1)
		if abs0 < dataStart || abs1 < dataStart || abs1 > len(buf) {
			return nil, fmt.Errorf("tensor %s: data out of bounds", name)
		}
		out[name] = stTensor{DType: ti.DType, Shape: ti.Shape, Data: buf[abs0:abs1]}
	}
	return out, nil
}

// readConfig loads config.json from the model directory and extracts the model
// architecture and the tie_word_embeddings flag alongside the raw config map.
func readConfig(modelPath string) (map[string]interface{}, string, bool, error) {
	configPath := filepath.Join(modelPath, "config.json")
	configData, err := os.ReadFile(configPath)
	if err != nil {
		return nil, "", false, fmt.Errorf("read config.json: %w", err)
	}
	var config map[string]interface{}
	if err := json.Unmarshal(configData, &config); err != nil {
		return nil, "", false, fmt.Errorf("parse config.json: %w", err)
	}
	architecture := "UnknownForCausalLM"
	if archs, ok := config["architectures"].([]interface{}); ok && len(archs) > 0 {
		if s, ok := archs[0].(string); ok {
			architecture = s
		}
	}
	tieWordEmbeddings := false
	if tie, ok := config["tie_word_embeddings"].(bool); ok {
		tieWordEmbeddings = tie
	}
	return config, architecture, tieWordEmbeddings, nil
}

// findSafetensorsFiles locates the safetensors shards for a model: a single
// model.safetensors, the files listed in model.safetensors.index.json, or,
// failing both, any *.safetensors files in the directory.
func findSafetensorsFiles(modelPath string) []string {
	single := filepath.Join(modelPath, "model.safetensors")
	if _, err := os.Stat(single); err == nil {
		return []string{single}
	}
	indexPath := filepath.Join(modelPath, "model.safetensors.index.json")
	if indexData, err := os.ReadFile(indexPath); err == nil {
		var index struct {
			WeightMap map[string]string `json:"weight_map"`
		}
		if json.Unmarshal(indexData, &index) == nil {
			fileSet := make(map[string]bool)
			for _, f := range index.WeightMap {
				fileSet[f] = true
			}
			var files []string
			for f := range fileSet {
				files = append(files, filepath.Join(modelPath, f))
			}
			sort.Strings(files)
			return files
		}
	}
	pattern := filepath.Join(modelPath, "*.safetensors")
	files, _ := filepath.Glob(pattern)
	sort.Strings(files)
	return files
}

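// stTensorRef describes where a tensor's bytes live inside a safetensors file,
// so tensor data can be read lazily instead of loading the whole file into memory.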
type stTensorRef struct {
	Name   string
	DType  string
	Shape  []uint64
	Offset int64
	Size   int64
}

// tensorResult carries one converted tensor (or an error) from a worker back to
// the writer loop. Either data or reader is set; release, when non-nil, returns
// the tensor's memory budget to the limiter.
type tensorResult struct {
	name    string
	shape   []uint64
	dtype   loader.DType
	data    []byte
	reader  io.Reader
	size    uint64
	err     error
	release func()
}

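// convertSafetensorsFile converts every tensor in one safetensors shard. Tensor
// conversion is fanned out to a pool of workers over a jobs channel; results are
// drained on the calling goroutine, which is the only one that touches writer.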
func convertSafetensorsFile(sfPath string, writer *loader.Writer, spec *Spec, workers int, limiter *memLimiter) error {
	f, err := os.Open(sfPath)
	if err != nil {
		return fmt.Errorf("open %s: %w", sfPath, err)
	}
	defer func() {
		_ = f.Close()
	}()
	refs, err := listSafetensorsTensorRefs(f)
	if err != nil {
		return fmt.Errorf("parse %s: %w", sfPath, err)
	}
	sort.Slice(refs, func(i, j int) bool { return refs[i].Name < refs[j].Name })
	// Filter skipped tensors up front so the progress counter matches the work queued.
	filtered := refs[:0]
	for _, ref := range refs {
		if spec.TieWordEmbeddings && strings.Contains(ref.Name, "lm_head") {
			continue
		}
		if spec.SkipTensor != nil && spec.SkipTensor(ref.Name) {
			continue
		}
		filtered = append(filtered, ref)
	}
	jobs := make(chan stTensorRef, workers*2)
	results := make(chan tensorResult, workers*2)
	var wg sync.WaitGroup
	workerFn := func() {
		defer wg.Done()
		for ref := range jobs {
			results <- convertSafetensorsTensor(f, ref, spec, limiter)
		}
	}
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go workerFn()
	}
	go func() {
		for _, ref := range filtered {
			jobs <- ref
		}
		close(jobs)
	}()
	go func() {
		wg.Wait()
		close(results)
	}()
	var firstErr error
	processed := 0
	total := len(filtered)
	for r := range results {
		if r.err != nil {
			if firstErr == nil {
				firstErr = r.err
			}
			if r.release != nil {
				r.release()
			}
			continue
		}
		var writeErr error
		if r.reader != nil {
			writeErr = writer.AddTensorFromReader(r.name, r.dtype, r.shape, r.reader, r.size)
		} else {
			writeErr = writer.AddTensor(r.name, r.dtype, r.shape, r.data)
		}
		if writeErr != nil && firstErr == nil {
			firstErr = fmt.Errorf("write %s: %w", r.name, writeErr)
		}
		if r.release != nil {
			r.release()
		}
		processed++
		if processed == 1 || processed == total || processed%256 == 0 {
			fmt.Printf("convert: %d/%d %s\n", processed, total, r.name)
		}
	}
	return firstErr
}

// convertSafetensorsTensor converts a single tensor: F32/F16/BF16 tensors that
// will not be quantized are streamed through unchanged, everything else is
// decoded to float32 and either re-encoded as F32 or quantized to a k-quant type.
func convertSafetensorsTensor(f *os.File, ref stTensorRef, spec *Spec, limiter *memLimiter) tensorResult {
	shapeInt := make([]int, len(ref.Shape))
	shapeIntOK := true
	maxInt := uint64(^uint(0) >> 1)
	for i, s := range ref.Shape {
		if s > maxInt {
			shapeIntOK = false
			break
		}
		shapeInt[i] = int(s)
	}
	isQuantizable := false
	if shapeIntOK && spec.IsQuantizable != nil {
		isQuantizable = spec.IsQuantizable(ref.Name, shapeInt, spec.BaseQuant)
	}
	qt := spec.BaseQuant
	if spec.ResolveQuant != nil {
		qt = spec.ResolveQuant(ref.Name, spec.BaseQuant)
	}
	shouldWriteF32 := !isQuantizable || qt == quant.TypeF32 || qt == quant.TypeF16 || qt == ""
	if shouldWriteF32 {
		// Pass floating-point source tensors through without decoding.
		switch strings.ToUpper(strings.TrimSpace(ref.DType)) {
		case "F32":
			return tensorResult{
				name:   ref.Name,
				shape:  ref.Shape,
				dtype:  loader.F32,
				reader: io.NewSectionReader(f, ref.Offset, ref.Size),
				size:   uint64(ref.Size),
			}
		case "F16":
			return tensorResult{
				name:   ref.Name,
				shape:  ref.Shape,
				dtype:  loader.F16,
				reader: io.NewSectionReader(f, ref.Offset, ref.Size),
				size:   uint64(ref.Size),
			}
		case "BF16":
			return tensorResult{
				name:   ref.Name,
				shape:  ref.Shape,
				dtype:  loader.BF16,
				reader: io.NewSectionReader(f, ref.Offset, ref.Size),
				size:   uint64(ref.Size),
			}
		}
	}
	elementsU64, err := tensorElementCountU64(ref.Shape)
	if err != nil {
		return tensorResult{name: ref.Name, err: fmt.Errorf("%s: %w", ref.Name, err)}
	}
	if elementsU64 > maxInt {
		return tensorResult{name: ref.Name, err: fmt.Errorf("%s: tensor too large", ref.Name)}
	}
	// If we quantize, pick a deterministic fallback upfront so memory estimation matches.
	if !shouldWriteF32 {
		switch qt {
		case quant.TypeQ8K, quant.TypeQ6K, quant.TypeQ5K, quant.TypeQ4K, quant.TypeQ3K, quant.TypeQ2K:
			// OK
		default:
			qt = quant.TypeQ4K
		}
	}
	memCost, err := estimateTensorMemory(elementsU64, shouldWriteF32, qt)
	if err != nil {
		return tensorResult{name: ref.Name, err: fmt.Errorf("%s: %w", ref.Name, err)}
	}
	release := limiter.Acquire(memCost)
	n := int(elementsU64)
	floats := make([]float32, n)
	if err := readTensorFloat32FromFile(f, ref, floats); err != nil {
		return tensorResult{name: ref.Name, err: fmt.Errorf("%s: %w", ref.Name, err), release: release}
	}
	if shouldWriteF32 {
		return tensorResult{name: ref.Name, shape: ref.Shape, dtype: loader.F32, data: float32ToBytes(floats), release: release}
	}
	var outData []byte
	switch qt {
	case quant.TypeQ8K:
		outData = quant.QuantizeQ8K(floats)
	case quant.TypeQ5K:
		outData = quant.QuantizeQ5K(floats)
	case quant.TypeQ6K:
		outData = quant.QuantizeQ6K(floats)
	case quant.TypeQ4K:
		outData = quant.QuantizeQ4K(floats)
	case quant.TypeQ3K:
		outData = quant.QuantizeQ3K(floats)
	case quant.TypeQ2K:
		outData = quant.QuantizeQ2K(floats)
	default:
		outData = quant.QuantizeQ4K(floats)
		qt = quant.TypeQ4K
	}
	return tensorResult{name: ref.Name, shape: ref.Shape, dtype: qt.ToDType(), data: outData, release: release}
}

// Safetensors headers larger than this are rejected as malformed.
const maxSafetensorsHeaderSize = 128 << 20 // 128 MiB

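// listSafetensorsTensorRefs reads the safetensors header from f and returns a
// bounds-checked reference (absolute offset and size) for every tensor, without
// loading any tensor data.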
func listSafetensorsTensorRefs(f *os.File) ([]stTensorRef, error) {
	header, dataStart, fileSize, err := readSafetensorsHeaderFromFile(f)
	if err != nil {
		return nil, err
	}
	out := make([]stTensorRef, 0, len(header))
	for name, ti := range header {
		if name == "__metadata__" {
			continue
		}
		if len(ti.DataOffsets) != 2 {
			return nil, fmt.Errorf("tensor %s: invalid data_offsets", name)
		}
		o0 := ti.DataOffsets[0]
		o1 := ti.DataOffsets[1]
		if o1 < o0 {
			return nil, fmt.Errorf("tensor %s: invalid data_offsets range", name)
		}
		abs0 := dataStart + int64(o0)
		abs1 := dataStart + int64(o1)
		if abs0 < dataStart || abs1 < dataStart || abs1 > fileSize {
			return nil, fmt.Errorf("tensor %s: data out of bounds", name)
		}
		out = append(out, stTensorRef{
			Name:   name,
			DType:  ti.DType,
			Shape:  ti.Shape,
			Offset: abs0,
			Size:   abs1 - abs0,
		})
	}
	return out, nil
}

// readSafetensorsHeaderFromFile reads and validates the JSON header of an open
// safetensors file, returning the parsed header, the offset where tensor data
// begins, and the file size.
func readSafetensorsHeaderFromFile(f *os.File) (map[string]stHeaderTensor, int64, int64, error) {
	fi, err := f.Stat()
	if err != nil {
		return nil, 0, 0, fmt.Errorf("stat safetensors: %w", err)
	}
	fileSize := fi.Size()
	if fileSize < 8 {
		return nil, 0, 0, fmt.Errorf("invalid safetensors: too short")
	}
	var headerLenBuf [8]byte
	if _, err := f.ReadAt(headerLenBuf[:], 0); err != nil {
		return nil, 0, 0, fmt.Errorf("read header len: %w", err)
	}
	headerLen := binary.LittleEndian.Uint64(headerLenBuf[:])
	if headerLen > maxSafetensorsHeaderSize {
		return nil, 0, 0, fmt.Errorf("invalid safetensors: header too large")
	}
	headerLenInt := int(headerLen)
	headerStart := int64(8)
	headerEnd := headerStart + int64(headerLen)
	if headerEnd > fileSize {
		return nil, 0, 0, fmt.Errorf("invalid safetensors: header out of bounds")
	}
	headerBytes := make([]byte, headerLenInt)
	if _, err := f.ReadAt(headerBytes, headerStart); err != nil {
		return nil, 0, 0, fmt.Errorf("read header: %w", err)
	}
	var header map[string]stHeaderTensor
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		return nil, 0, 0, fmt.Errorf("invalid safetensors header json: %w", err)
	}
	dataStart := headerEnd
	return header, dataStart, fileSize, nil
}

// tensorElementCountU64 multiplies the shape dimensions, guarding against
// uint64 overflow. An empty shape counts as a single (scalar) element.
func tensorElementCountU64(shape []uint64) (uint64, error) {
	if len(shape) == 0 {
		return 1, nil
	}
	total := uint64(1)
	maxU64 := ^uint64(0)
	for _, d := range shape {
		if d == 0 {
			return 0, nil
		}
		if total > maxU64/d {
			return 0, fmt.Errorf("tensor shape overflows element count")
		}
		total *= d
	}
	return total, nil
}

// tensorElementCount is tensorElementCountU64 narrowed to int, failing if the
// count does not fit.
func tensorElementCount(shape []uint64) (int, error) {
	total, err := tensorElementCountU64(shape)
	if err != nil {
		return 0, err
	}
	maxInt := uint64(^uint(0) >> 1)
	if total > maxInt {
		return 0, fmt.Errorf("tensor too large")
	}
	return int(total), nil
}

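// memLimiter is a byte-counting semaphore: workers Acquire an estimated memory
// cost before decoding a tensor and release it once the tensor has been written,
// which bounds the total bytes held in flight across workers.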
type memLimiter struct {
	max  uint64
	used uint64
	mu   sync.Mutex
	cond *sync.Cond
}

// newMemLimiter returns a limiter capped at maxBytes, or nil (no limiting) when
// maxBytes is zero.
func newMemLimiter(maxBytes uint64) *memLimiter {
	if maxBytes == 0 {
		return nil
	}
	l := &memLimiter{max: maxBytes}
	l.cond = sync.NewCond(&l.mu)
	return l
}

// Acquire blocks until bytes can be reserved and returns a release function.
// A nil limiter or a zero request returns nil; requests larger than the cap are
// clamped so a single oversized tensor cannot block the pipeline forever.
func (l *memLimiter) Acquire(bytes uint64) func() {
	if l == nil || l.max == 0 || bytes == 0 {
		return nil
	}
	if bytes > l.max {
		bytes = l.max
	}
	l.mu.Lock()
	for l.used+bytes > l.max {
		l.cond.Wait()
	}
	l.used += bytes
	l.mu.Unlock()
	return func() {
		l.mu.Lock()
		if bytes >= l.used {
			l.used = 0
		} else {
			l.used -= bytes
		}
		l.mu.Unlock()
		l.cond.Broadcast()
	}
}

// estimateTensorMemory returns the peak bytes a tensor conversion may hold: the
// decoded float32 buffer plus either a second float32-sized byte buffer (for F32
// output) or the quantized output buffer.
func estimateTensorMemory(elements uint64, writeF32 bool, qt quant.QuantType) (uint64, error) {
	floatBytes, err := mulU64(elements, 4)
	if err != nil {
		return 0, err
	}
	if writeF32 {
		// We allocate both []float32 and a separate []byte buffer via float32ToBytes.
		return addU64(floatBytes, floatBytes)
	}
	outBytes, err := quantOutputBytes(qt, elements)
	if err != nil {
		return 0, err
	}
	return addU64(floatBytes, outBytes)
}

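// quantOutputBytes computes the serialized size of a k-quant tensor: elements
// are grouped into 256-element super-blocks, and each block has a fixed byte
// size per quant type (for example 144 bytes for Q4_K and 292 bytes for Q8_K).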
func quantOutputBytes(qt quant.QuantType, elements uint64) (uint64, error) {
	const block = 256
	blocks := (elements + block - 1) / block
	var bytesPerBlock uint64
	switch qt {
	case quant.TypeQ8K:
		bytesPerBlock = 292
	case quant.TypeQ6K:
		bytesPerBlock = 210
	case quant.TypeQ5K:
		bytesPerBlock = 176
	case quant.TypeQ4K:
		bytesPerBlock = 144
	case quant.TypeQ3K:
		bytesPerBlock = 110
	case quant.TypeQ2K:
		bytesPerBlock = 84
	default:
		return 0, fmt.Errorf("unsupported quant type: %s", qt)
	}
	if blocks > 0 && blocks > (^uint64(0))/bytesPerBlock {
		return 0, fmt.Errorf("quant output size overflows")
	}
	return blocks * bytesPerBlock, nil
}

// addU64 and mulU64 are overflow-checked helpers for size arithmetic.
func addU64(a, b uint64) (uint64, error) {
	if a > (^uint64(0))-b {
		return 0, fmt.Errorf("size overflows")
	}
	return a + b, nil
}

func mulU64(a, b uint64) (uint64, error) {
	if a == 0 || b == 0 {
		return 0, nil
	}
	if a > (^uint64(0))/b {
		return 0, fmt.Errorf("size overflows")
	}
	return a * b, nil
}

// stDecodeBufPool recycles scratch buffers used when streaming tensor data off disk.
var stDecodeBufPool = sync.Pool{
	New: func() any {
		return make([]byte, 1<<20) // 1 MiB
	},
}

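// readTensorFloat32FromFile streams a tensor's bytes from f (using the offset
// and size in ref) and decodes them into dst as float32, after checking that the
// byte size matches the element count for the source dtype.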
func readTensorFloat32FromFile(f *os.File, ref stTensorRef, dst []float32) error {
	dt := strings.ToUpper(strings.TrimSpace(ref.DType))
	switch dt {
	case "F32":
		if ref.Size != int64(len(dst))*4 {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeF32(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	case "F16":
		if ref.Size != int64(len(dst))*2 {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeF16(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	case "BF16":
		if ref.Size != int64(len(dst))*2 {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeBF16(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	case "F8_E4M3", "F8_E4M3FN", "FLOAT8_E4M3FN":
		if ref.Size != int64(len(dst)) {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeF8E4M3(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	case "F8_E5M2", "F8_E5M2FN", "FLOAT8_E5M2":
		if ref.Size != int64(len(dst)) {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeF8E5M2(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	default:
		return fmt.Errorf("unsupported dtype: %s", ref.DType)
	}
}

// decodeF32 reads little-endian float32 values in pooled chunks; the other
// decode* helpers below follow the same pattern for their element sizes.
func decodeF32(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf)/4 {
			n = len(buf) / 4
		}
		b := buf[:n*4]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = math.Float32frombits(binary.LittleEndian.Uint32(b[j*4 : j*4+4]))
		}
		i += n
	}
	return nil
}

func decodeF16(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf)/2 {
			n = len(buf) / 2
		}
		b := buf[:n*2]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = float16ToFloat32(binary.LittleEndian.Uint16(b[j*2 : j*2+2]))
		}
		i += n
	}
	return nil
}

func decodeBF16(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf)/2 {
			n = len(buf) / 2
		}
		b := buf[:n*2]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = bfloat16ToFloat32(binary.LittleEndian.Uint16(b[j*2 : j*2+2]))
		}
		i += n
	}
	return nil
}

func decodeF8E4M3(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf) {
			n = len(buf)
		}
		b := buf[:n]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = float8E4M3ToFloat32(b[j])
		}
		i += n
	}
	return nil
}

func decodeF8E5M2(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf) {
			n = len(buf)
		}
		b := buf[:n]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = float8E5M2ToFloat32(b[j])
		}
		i += n
	}
	return nil
}

// toFloat32 converts an in-memory tensor payload to float32. It is the
// buffer-based counterpart of readTensorFloat32FromFile.
func toFloat32(data []byte, dtype string) ([]float32, error) {
	dt := strings.ToUpper(strings.TrimSpace(dtype))
	switch dt {
	case "F32":
		n := len(data) / 4
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			bits := binary.LittleEndian.Uint32(data[i*4 : i*4+4])
			result[i] = math.Float32frombits(bits)
		}
		return result, nil
	case "F16":
		n := len(data) / 2
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			bits := binary.LittleEndian.Uint16(data[i*2 : i*2+2])
			result[i] = float16ToFloat32(bits)
		}
		return result, nil
	case "BF16":
		n := len(data) / 2
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			bits := binary.LittleEndian.Uint16(data[i*2 : i*2+2])
			result[i] = bfloat16ToFloat32(bits)
		}
		return result, nil
	// Float8 formats (common in newer checkpoints).
	case "F8_E4M3", "F8_E4M3FN", "FLOAT8_E4M3FN":
		n := len(data)
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			result[i] = float8E4M3ToFloat32(data[i])
		}
		return result, nil
	case "F8_E5M2", "F8_E5M2FN", "FLOAT8_E5M2":
		n := len(data)
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			result[i] = float8E5M2ToFloat32(data[i])
		}
		return result, nil
	default:
		return nil, fmt.Errorf("unsupported dtype: %s", dtype)
	}
}

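// float8E4M3ToFloat32 expands an FP8 value with 1 sign, 4 exponent (bias 7) and
// 3 mantissa bits. Note: it maps exponent 0xF with a zero mantissa to +/-Inf;
// the E4M3FN variant defines no infinities, so those bit patterns are handled
// loosely here rather than per that spec.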
func float8E4M3ToFloat32(b byte) float32 {
	sign := (b >> 7) & 0x1
	exp := (b >> 3) & 0xF
	mant := b & 0x7
	if exp == 0 {
		if mant == 0 {
			// +/- 0
			if sign == 1 {
				return float32(math.Copysign(0, -1))
			}
			return 0
		}
		// subnormal: (mant / 2^3) * 2^(1-bias)
		v := float32(mant) / 8.0
		v = float32(math.Ldexp(float64(v), 1-7))
		if sign == 1 {
			v = -v
		}
		return v
	}
	if exp == 0xF {
		// treat as inf/nan
		if mant == 0 {
			if sign == 1 {
				return float32(math.Inf(-1))
			}
			return float32(math.Inf(1))
		}
		return float32(math.NaN())
	}
	// normal: (1 + mant/2^3) * 2^(exp-bias)
	frac := 1.0 + float64(mant)/8.0
	v := math.Ldexp(frac, int(exp)-7)
	if sign == 1 {
		v = -v
	}
	return float32(v)
}

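// float8E5M2ToFloat32 expands an FP8 value with 1 sign, 5 exponent (bias 15) and
// 2 mantissa bits, following the same conventions as float8E4M3ToFloat32.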
func float8E5M2ToFloat32(b byte) float32 {
	sign := (b >> 7) & 0x1
	exp := (b >> 2) & 0x1F
	mant := b & 0x3
	if exp == 0 {
		if mant == 0 {
			if sign == 1 {
				return float32(math.Copysign(0, -1))
			}
			return 0
		}
		// subnormal: (mant / 2^2) * 2^(1-bias)
		v := float32(mant) / 4.0
		v = float32(math.Ldexp(float64(v), 1-15))
		if sign == 1 {
			v = -v
		}
		return v
	}
	if exp == 0x1F {
		if mant == 0 {
			if sign == 1 {
				return float32(math.Inf(-1))
			}
			return float32(math.Inf(1))
		}
		return float32(math.NaN())
	}
	frac := 1.0 + float64(mant)/4.0
	v := math.Ldexp(frac, int(exp)-15)
	if sign == 1 {
		v = -v
	}
	return float32(v)
}

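// float16ToFloat32 converts an IEEE 754 half-precision bit pattern to float32,
// renormalizing subnormals and preserving infinities and NaN.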
func float16ToFloat32(bits uint16) float32 {
	sign := uint32(bits&0x8000) << 16
	exp := uint32(bits&0x7C00) >> 10
	mant := uint32(bits & 0x03FF)
	if exp == 0 {
		if mant == 0 {
			return math.Float32frombits(sign)
		}
		// Subnormal: shift the mantissa up to restore the implicit leading bit.
		for mant&0x0400 == 0 {
			mant <<= 1
			exp--
		}
		exp++
		mant &= 0x03FF
	} else if exp == 0x1F {
		if mant == 0 {
			return math.Float32frombits(sign | 0x7F800000) // +/- Inf
		}
		return math.Float32frombits(sign | 0x7FC00000) // NaN
	}
	exp += 127 - 15
	return math.Float32frombits(sign | (exp << 23) | (mant << 13))
}

func bfloat16ToFloat32(bits uint16) float32 {
	// bfloat16 is the top 16 bits of a float32, so widening is a single shift.
	return math.Float32frombits(uint32(bits) << 16)
}

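// float32ToBytes serializes a float32 slice to little-endian bytes for writing
// as an F32 tensor.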
func float32ToBytes(floats []float32) []byte {
	data := make([]byte, len(floats)*4)
	for i, f := range floats {
		bits := math.Float32bits(f)
		binary.LittleEndian.PutUint32(data[i*4:], bits)
	}
	return data
}