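// Package convert converts safetensors model checkpoints into the makarna
// loader format, optionally quantizing weight tensors along the way.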
package convert

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"os"
	"path/filepath"
	"runtime"
	"sort"
	"strings"
	"sync"

	"makarna/pkg/loader"
	"makarna/pkg/quant"
)
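// Options configures a conversion run.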
type Options struct {
	BaseQuant quant.QuantType
	MixMode   bool
	Workers   int
	// MaxInFlightBytes bounds peak RAM by limiting concurrently processed tensors
	// (decoded float32 + output buffer). 0 = auto.
	MaxInFlightBytes uint64
}
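// ConvertDirectory reads a safetensors model directory (config.json, an
// optional tokenizer.json, and one or more *.safetensors shards) and writes a
// single converted model file to outputPath; on failure the partial output
// file is removed. Illustrative usage (paths and output name are placeholders):
//
//	err := ConvertDirectory("/models/hf-checkpoint", "/models/model.bin", Options{
//		BaseQuant: quant.TypeQ4K,
//		Workers:   8,
//	})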
func ConvertDirectory(modelPath string, outputPath string, opts Options) error {
	config, architecture, tieWordEmbeddings, err := readConfig(modelPath)
	if err != nil {
		return err
	}
	spec := NewSpec(architecture, tieWordEmbeddings, opts.BaseQuant, opts.MixMode)

	sfFiles := findSafetensorsFiles(modelPath)
	if len(sfFiles) == 0 {
		return fmt.Errorf("no safetensors files found in %s", modelPath)
	}

	writer, err := loader.NewWriter(outputPath)
	if err != nil {
		return fmt.Errorf("create output: %w", err)
	}
	ok := false
	defer func() {
		if ok {
			return
		}
		_ = writer.Close()
		_ = os.Remove(outputPath)
	}()

	writer.SetModelConfig(loader.ModelConfig{Architecture: architecture, Params: config})

	tokenizerPath := filepath.Join(modelPath, "tokenizer.json")
	if tokData, err := os.ReadFile(tokenizerPath); err == nil {
		writer.AddTokenizer(tokData)
	} else {
		altTokenizerPath := filepath.Join(modelPath, "gpt2_tokenizer", "tokenizer.json")
		if tokData, err := os.ReadFile(altTokenizerPath); err == nil {
			writer.AddTokenizer(tokData)
		}
	}

	workers := opts.Workers
	if workers <= 0 {
		workers = runtime.GOMAXPROCS(0)
		if workers <= 0 {
			workers = 1
		}
	}

	maxInFlight := opts.MaxInFlightBytes
	if maxInFlight == 0 {
		// Conservative default to keep conversion usable on machines with limited RAM.
		maxInFlight = 4 << 30 // 4 GiB
	}
	limiter := newMemLimiter(maxInFlight)

	for _, sfPath := range sfFiles {
		if err := convertSafetensorsFile(sfPath, writer, spec, workers, limiter); err != nil {
			return err
		}
	}

	if err := writer.Close(); err != nil {
		return err
	}
	ok = true
	return nil
}
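// stTensor is a tensor entry decoded from an in-memory safetensors buffer;
// Data aliases the input buffer rather than copying it.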
type stTensor struct {
	DType string
	Shape []uint64
	Data  []byte
}
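// stHeaderTensor mirrors one entry of the safetensors JSON header, e.g.
//
//	"model.embed_tokens.weight": {"dtype": "BF16", "shape": [32000, 4096], "data_offsets": [0, 262144000]}
//
// (the tensor name and numbers above are illustrative).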
type stHeaderTensor struct {
	DType       string   `json:"dtype"`
	Shape       []uint64 `json:"shape"`
	DataOffsets []uint64 `json:"data_offsets"`
}
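// readSafetensorsRaw parses a complete safetensors file held in memory: an
// 8-byte little-endian header length, a JSON header, then the raw tensor data
// section. Returned tensors alias buf.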
func readSafetensorsRaw(buf []byte) (map[string]stTensor, error) {
	if len(buf) < 8 {
		return nil, fmt.Errorf("invalid safetensors: too short")
	}
	headerLen := binary.LittleEndian.Uint64(buf[0:8])
	if headerLen > uint64(len(buf)-8) {
		return nil, fmt.Errorf("invalid safetensors: header too large")
	}
	headerStart := 8
	headerEnd := 8 + int(headerLen)
	dataStart := headerEnd

	var header map[string]stHeaderTensor
	if err := json.Unmarshal(buf[headerStart:headerEnd], &header); err != nil {
		return nil, fmt.Errorf("invalid safetensors header json: %w", err)
	}

	out := make(map[string]stTensor, len(header))
	for name, ti := range header {
		if name == "__metadata__" {
			continue
		}
		if len(ti.DataOffsets) != 2 {
			return nil, fmt.Errorf("tensor %s: invalid data_offsets", name)
		}
		o0 := ti.DataOffsets[0]
		o1 := ti.DataOffsets[1]
		if o1 < o0 {
			return nil, fmt.Errorf("tensor %s: invalid data_offsets range", name)
		}
		abs0 := dataStart + int(o0)
		abs1 := dataStart + int(o1)
		if abs0 < dataStart || abs1 < dataStart || abs1 > len(buf) {
			return nil, fmt.Errorf("tensor %s: data out of bounds", name)
		}
		out[name] = stTensor{DType: ti.DType, Shape: ti.Shape, Data: buf[abs0:abs1]}
	}
	return out, nil
}
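// readConfig loads config.json from the model directory and extracts the raw
// parameter map, the first entry of "architectures", and "tie_word_embeddings".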
func readConfig(modelPath string) (map[string]interface{}, string, bool, error) {
	configPath := filepath.Join(modelPath, "config.json")
	configData, err := os.ReadFile(configPath)
	if err != nil {
		return nil, "", false, fmt.Errorf("read config.json: %w", err)
	}
	var config map[string]interface{}
	if err := json.Unmarshal(configData, &config); err != nil {
		return nil, "", false, fmt.Errorf("parse config.json: %w", err)
	}

	architecture := "UnknownForCausalLM"
	if archs, ok := config["architectures"].([]interface{}); ok && len(archs) > 0 {
		if s, ok := archs[0].(string); ok {
			architecture = s
		}
	}

	tieWordEmbeddings := false
	if tie, ok := config["tie_word_embeddings"].(bool); ok {
		tieWordEmbeddings = tie
	}
	return config, architecture, tieWordEmbeddings, nil
}
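// findSafetensorsFiles locates the checkpoint shards to convert: a single
// model.safetensors if present, otherwise the files referenced by
// model.safetensors.index.json, otherwise any *.safetensors in the directory.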
func findSafetensorsFiles(modelPath string) []string {
	single := filepath.Join(modelPath, "model.safetensors")
	if _, err := os.Stat(single); err == nil {
		return []string{single}
	}

	indexPath := filepath.Join(modelPath, "model.safetensors.index.json")
	if indexData, err := os.ReadFile(indexPath); err == nil {
		var index struct {
			WeightMap map[string]string `json:"weight_map"`
		}
		if json.Unmarshal(indexData, &index) == nil {
			fileSet := make(map[string]bool)
			for _, f := range index.WeightMap {
				fileSet[f] = true
			}
			var files []string
			for f := range fileSet {
				files = append(files, filepath.Join(modelPath, f))
			}
			sort.Strings(files)
			return files
		}
	}

	pattern := filepath.Join(modelPath, "*.safetensors")
	files, _ := filepath.Glob(pattern)
	sort.Strings(files)
	return files
}
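// stTensorRef describes a tensor's location within a safetensors file without
// loading its data: absolute byte offset and size alongside dtype and shape.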
type stTensorRef struct {
	Name   string
	DType  string
	Shape  []uint64
	Offset int64
	Size   int64
}
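// tensorResult is what a conversion worker hands back to the writer loop:
// either a streaming reader (passthrough tensors) or a materialized data
// buffer, plus an optional release callback for the memory limiter.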
type tensorResult struct {
	name    string
	shape   []uint64
	dtype   loader.DType
	data    []byte
	reader  io.Reader
	size    uint64
	err     error
	release func()
}
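// convertSafetensorsFile streams one safetensors shard into the writer. Tensor
// refs are fanned out to a pool of workers over a jobs channel; the single
// consumer loop below serializes all writer calls, so the writer itself does
// not need to be thread-safe. Tied lm_head weights and spec-skipped tensors
// are filtered out before they are enqueued.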
func convertSafetensorsFile(sfPath string, writer *loader.Writer, spec *Spec, workers int, limiter *memLimiter) error {
	f, err := os.Open(sfPath)
	if err != nil {
		return fmt.Errorf("open %s: %w", sfPath, err)
	}
	defer func() {
		_ = f.Close()
	}()

	refs, err := listSafetensorsTensorRefs(f)
	if err != nil {
		return fmt.Errorf("parse %s: %w", sfPath, err)
	}
	sort.Slice(refs, func(i, j int) bool { return refs[i].Name < refs[j].Name })

	jobs := make(chan stTensorRef, workers*2)
	results := make(chan tensorResult, workers*2)

	var wg sync.WaitGroup
	workerFn := func() {
		defer wg.Done()
		for ref := range jobs {
			results <- convertSafetensorsTensor(f, ref, spec, limiter)
		}
	}
	wg.Add(workers)
	for i := 0; i < workers; i++ {
		go workerFn()
	}

	go func() {
		for _, ref := range refs {
			if spec.TieWordEmbeddings && strings.Contains(ref.Name, "lm_head") {
				continue
			}
			if spec.SkipTensor != nil && spec.SkipTensor(ref.Name) {
				continue
			}
			jobs <- ref
		}
		close(jobs)
	}()

	go func() {
		wg.Wait()
		close(results)
	}()

	var firstErr error
	processed := 0
	total := len(refs)
	for r := range results {
		if r.err != nil {
			if firstErr == nil {
				firstErr = r.err
			}
			if r.release != nil {
				r.release()
			}
			continue
		}
		if r.reader != nil {
			if err := writer.AddTensorFromReader(r.name, r.dtype, r.shape, r.reader, r.size); err != nil && firstErr == nil {
				firstErr = fmt.Errorf("write %s: %w", r.name, err)
			}
			if r.release != nil {
				r.release()
			}
			processed++
			if processed == 1 || processed == total || processed%256 == 0 {
				fmt.Printf("convert: %d/%d %s\n", processed, total, r.name)
			}
			continue
		}
		if err := writer.AddTensor(r.name, r.dtype, r.shape, r.data); err != nil && firstErr == nil {
			firstErr = fmt.Errorf("write %s: %w", r.name, err)
		}
		if r.release != nil {
			r.release()
		}
		processed++
		if processed == 1 || processed == total || processed%256 == 0 {
			fmt.Printf("convert: %d/%d %s\n", processed, total, r.name)
		}
	}
	return firstErr
}
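// convertSafetensorsTensor converts a single tensor. F32/F16/BF16 tensors that
// the spec does not quantize are passed through as section readers without
// touching the memory limiter; everything else is decoded to float32 (and
// quantized if requested) under a limiter reservation sized by
// estimateTensorMemory.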
func convertSafetensorsTensor(f *os.File, ref stTensorRef, spec *Spec, limiter *memLimiter) tensorResult {
	shapeInt := make([]int, len(ref.Shape))
	shapeIntOK := true
	maxInt := uint64(^uint(0) >> 1)
	for i, s := range ref.Shape {
		if s > maxInt {
			shapeIntOK = false
			break
		}
		shapeInt[i] = int(s)
	}

	isQuantizable := false
	if shapeIntOK && spec.IsQuantizable != nil {
		isQuantizable = spec.IsQuantizable(ref.Name, shapeInt, spec.BaseQuant)
	}
	qt := spec.BaseQuant
	if spec.ResolveQuant != nil {
		qt = spec.ResolveQuant(ref.Name, spec.BaseQuant)
	}

	shouldWriteF32 := !isQuantizable || qt == quant.TypeF32 || qt == quant.TypeF16 || qt == ""
	if shouldWriteF32 {
		switch strings.ToUpper(strings.TrimSpace(ref.DType)) {
		case "F32":
			return tensorResult{
				name:   ref.Name,
				shape:  ref.Shape,
				dtype:  loader.F32,
				reader: io.NewSectionReader(f, ref.Offset, ref.Size),
				size:   uint64(ref.Size),
			}
		case "F16":
			return tensorResult{
				name:   ref.Name,
				shape:  ref.Shape,
				dtype:  loader.F16,
				reader: io.NewSectionReader(f, ref.Offset, ref.Size),
				size:   uint64(ref.Size),
			}
		case "BF16":
			return tensorResult{
				name:   ref.Name,
				shape:  ref.Shape,
				dtype:  loader.BF16,
				reader: io.NewSectionReader(f, ref.Offset, ref.Size),
				size:   uint64(ref.Size),
			}
		}
	}

	elementsU64, err := tensorElementCountU64(ref.Shape)
	if err != nil {
		return tensorResult{name: ref.Name, err: fmt.Errorf("%s: %w", ref.Name, err)}
	}
	if elementsU64 > maxInt {
		return tensorResult{name: ref.Name, err: fmt.Errorf("%s: tensor too large", ref.Name)}
	}

	// If we quantize, pick a deterministic fallback upfront so memory estimation matches.
	if !shouldWriteF32 {
		switch qt {
		case quant.TypeQ8K, quant.TypeQ6K, quant.TypeQ5K, quant.TypeQ4K, quant.TypeQ3K, quant.TypeQ2K:
			// OK
		default:
			qt = quant.TypeQ4K
		}
	}

	memCost, err := estimateTensorMemory(elementsU64, shouldWriteF32, qt)
	if err != nil {
		return tensorResult{name: ref.Name, err: fmt.Errorf("%s: %w", ref.Name, err)}
	}
	release := limiter.Acquire(memCost)

	n := int(elementsU64)
	floats := make([]float32, n)
	if err := readTensorFloat32FromFile(f, ref, floats); err != nil {
		return tensorResult{name: ref.Name, err: fmt.Errorf("%s: %w", ref.Name, err), release: release}
	}

	if shouldWriteF32 {
		return tensorResult{name: ref.Name, shape: ref.Shape, dtype: loader.F32, data: float32ToBytes(floats), release: release}
	}

	var outData []byte
	switch qt {
	case quant.TypeQ8K:
		outData = quant.QuantizeQ8K(floats)
	case quant.TypeQ5K:
		outData = quant.QuantizeQ5K(floats)
	case quant.TypeQ6K:
		outData = quant.QuantizeQ6K(floats)
	case quant.TypeQ4K:
		outData = quant.QuantizeQ4K(floats)
	case quant.TypeQ3K:
		outData = quant.QuantizeQ3K(floats)
	case quant.TypeQ2K:
		outData = quant.QuantizeQ2K(floats)
	default:
		outData = quant.QuantizeQ4K(floats)
		qt = quant.TypeQ4K
	}
	return tensorResult{name: ref.Name, shape: ref.Shape, dtype: qt.ToDType(), data: outData, release: release}
}
const maxSafetensorsHeaderSize = 128 << 20 // 128 MiB
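// listSafetensorsTensorRefs reads only the header of an open safetensors file
// and returns offset/size references for every tensor, validating that each
// data range lies inside the file.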
func listSafetensorsTensorRefs(f *os.File) ([]stTensorRef, error) {
	header, dataStart, fileSize, err := readSafetensorsHeaderFromFile(f)
	if err != nil {
		return nil, err
	}
	out := make([]stTensorRef, 0, len(header))
	for name, ti := range header {
		if name == "__metadata__" {
			continue
		}
		if len(ti.DataOffsets) != 2 {
			return nil, fmt.Errorf("tensor %s: invalid data_offsets", name)
		}
		o0 := ti.DataOffsets[0]
		o1 := ti.DataOffsets[1]
		if o1 < o0 {
			return nil, fmt.Errorf("tensor %s: invalid data_offsets range", name)
		}
		abs0 := dataStart + int64(o0)
		abs1 := dataStart + int64(o1)
		if abs0 < dataStart || abs1 < dataStart || abs1 > fileSize {
			return nil, fmt.Errorf("tensor %s: data out of bounds", name)
		}
		out = append(out, stTensorRef{
			Name:   name,
			DType:  ti.DType,
			Shape:  ti.Shape,
			Offset: abs0,
			Size:   abs1 - abs0,
		})
	}
	return out, nil
}
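// readSafetensorsHeaderFromFile parses the 8-byte little-endian header length
// and the JSON header via ReadAt, returning the header map, the offset where
// tensor data begins, and the total file size. The header length is capped at
// maxSafetensorsHeaderSize to reject corrupt or hostile files.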
func readSafetensorsHeaderFromFile(f *os.File) (map[string]stHeaderTensor, int64, int64, error) {
	fi, err := f.Stat()
	if err != nil {
		return nil, 0, 0, fmt.Errorf("stat safetensors: %w", err)
	}
	fileSize := fi.Size()
	if fileSize < 8 {
		return nil, 0, 0, fmt.Errorf("invalid safetensors: too short")
	}

	var headerLenBuf [8]byte
	if _, err := f.ReadAt(headerLenBuf[:], 0); err != nil {
		return nil, 0, 0, fmt.Errorf("read header len: %w", err)
	}
	headerLen := binary.LittleEndian.Uint64(headerLenBuf[:])
	if headerLen > maxSafetensorsHeaderSize {
		return nil, 0, 0, fmt.Errorf("invalid safetensors: header too large")
	}
	headerLenInt := int(headerLen)
	headerStart := int64(8)
	headerEnd := headerStart + int64(headerLen)
	if headerEnd > fileSize {
		return nil, 0, 0, fmt.Errorf("invalid safetensors: header out of bounds")
	}

	headerBytes := make([]byte, headerLenInt)
	if _, err := f.ReadAt(headerBytes, headerStart); err != nil {
		return nil, 0, 0, fmt.Errorf("read header: %w", err)
	}
	var header map[string]stHeaderTensor
	if err := json.Unmarshal(headerBytes, &header); err != nil {
		return nil, 0, 0, fmt.Errorf("invalid safetensors header json: %w", err)
	}
	dataStart := headerEnd
	return header, dataStart, fileSize, nil
}
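// tensorElementCountU64 multiplies the shape dimensions with overflow checking.
// A zero-length shape counts as a scalar (one element); any zero dimension
// yields zero elements.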
func tensorElementCountU64(shape []uint64) (uint64, error) {
	if len(shape) == 0 {
		return 1, nil
	}
	total := uint64(1)
	maxU64 := ^uint64(0)
	for _, d := range shape {
		if d == 0 {
			return 0, nil
		}
		if total > maxU64/d {
			return 0, fmt.Errorf("tensor shape overflows element count")
		}
		total *= d
	}
	return total, nil
}
func tensorElementCount(shape []uint64) (int, error) {
	total, err := tensorElementCountU64(shape)
	if err != nil {
		return 0, err
	}
	maxInt := uint64(^uint(0) >> 1)
	if total > maxInt {
		return 0, fmt.Errorf("tensor too large")
	}
	return int(total), nil
}
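// memLimiter is a counting semaphore over bytes. Acquire blocks until the
// requested amount fits under max and returns a release func, or nil when no
// reservation was taken (nil limiter or zero bytes), so callers nil-check
// before invoking it.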
type memLimiter struct {
	max  uint64
	used uint64
	mu   sync.Mutex
	cond *sync.Cond
}

func newMemLimiter(maxBytes uint64) *memLimiter {
	if maxBytes == 0 {
		return nil
	}
	l := &memLimiter{max: maxBytes}
	l.cond = sync.NewCond(&l.mu)
	return l
}
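// Acquire reserves bytes against the limit, blocking while the budget is
// exhausted. Requests larger than max are clamped to max so a single
// oversized tensor can still make progress (it then effectively runs alone).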
func (l *memLimiter) Acquire(bytes uint64) func() {
	if l == nil || l.max == 0 || bytes == 0 {
		return nil
	}
	if bytes > l.max {
		bytes = l.max
	}
	l.mu.Lock()
	for l.used+bytes > l.max {
		l.cond.Wait()
	}
	l.used += bytes
	l.mu.Unlock()

	return func() {
		l.mu.Lock()
		if bytes >= l.used {
			l.used = 0
		} else {
			l.used -= bytes
		}
		l.mu.Unlock()
		l.cond.Broadcast()
	}
}
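// estimateTensorMemory approximates the peak bytes held per tensor while it is
// in flight: the decoded float32 slice plus either a second float32-sized byte
// buffer (F32 output) or the quantized output buffer.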
func estimateTensorMemory(elements uint64, writeF32 bool, qt quant.QuantType) (uint64, error) {
	floatBytes, err := mulU64(elements, 4)
	if err != nil {
		return 0, err
	}
	if writeF32 {
		// We allocate both []float32 and a separate []byte buffer via float32ToBytes.
		return addU64(floatBytes, floatBytes)
	}
	outBytes, err := quantOutputBytes(qt, elements)
	if err != nil {
		return 0, err
	}
	return addU64(floatBytes, outBytes)
}
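// quantOutputBytes computes the serialized size of a K-quant tensor: elements
// are grouped into 256-element blocks, and each block has a fixed byte
// footprint per quant type (for example 144 bytes for TypeQ4K, 292 for
// TypeQ8K, matching the block layouts used by the quant package).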
func quantOutputBytes(qt quant.QuantType, elements uint64) (uint64, error) {
	const block = 256
	blocks := (elements + block - 1) / block
	var bytesPerBlock uint64
	switch qt {
	case quant.TypeQ8K:
		bytesPerBlock = 292
	case quant.TypeQ6K:
		bytesPerBlock = 210
	case quant.TypeQ5K:
		bytesPerBlock = 176
	case quant.TypeQ4K:
		bytesPerBlock = 144
	case quant.TypeQ3K:
		bytesPerBlock = 110
	case quant.TypeQ2K:
		bytesPerBlock = 84
	default:
		return 0, fmt.Errorf("unsupported quant type: %s", qt)
	}
	if blocks > 0 && blocks > (^uint64(0))/bytesPerBlock {
		return 0, fmt.Errorf("quant output size overflows")
	}
	return blocks * bytesPerBlock, nil
}
func addU64(a, b uint64) (uint64, error) {
	if a > (^uint64(0))-b {
		return 0, fmt.Errorf("size overflows")
	}
	return a + b, nil
}

func mulU64(a, b uint64) (uint64, error) {
	if a == 0 || b == 0 {
		return 0, nil
	}
	if a > (^uint64(0))/b {
		return 0, fmt.Errorf("size overflows")
	}
	return a * b, nil
}
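// stDecodeBufPool recycles 1 MiB scratch buffers shared by the streaming
// decoders below, keeping per-tensor decode allocations flat.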
var stDecodeBufPool = sync.Pool{
	New: func() any {
		return make([]byte, 1<<20) // 1 MiB
	},
}
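// readTensorFloat32FromFile decodes a tensor's bytes into dst as float32,
// validating that the stored byte size matches the element count for the
// source dtype (F32, F16, BF16, and the two float8 variants).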
func readTensorFloat32FromFile(f *os.File, ref stTensorRef, dst []float32) error {
	dt := strings.ToUpper(strings.TrimSpace(ref.DType))
	switch dt {
	case "F32":
		if ref.Size != int64(len(dst))*4 {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeF32(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	case "F16":
		if ref.Size != int64(len(dst))*2 {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeF16(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	case "BF16":
		if ref.Size != int64(len(dst))*2 {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeBF16(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	case "F8_E4M3", "F8_E4M3FN", "FLOAT8_E4M3FN":
		if ref.Size != int64(len(dst)) {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeF8E4M3(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	case "F8_E5M2", "F8_E5M2FN", "FLOAT8_E5M2":
		if ref.Size != int64(len(dst)) {
			return fmt.Errorf("unexpected byte size for %s (%s)", ref.Name, ref.DType)
		}
		return decodeF8E5M2(io.NewSectionReader(f, ref.Offset, ref.Size), dst)
	default:
		return fmt.Errorf("unsupported dtype: %s", ref.DType)
	}
}
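// decodeF32 and the decoders below stream from r through a pooled scratch
// buffer, converting chunk by chunk so large tensors never need a second
// full-size staging allocation.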
func decodeF32(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf)/4 {
			n = len(buf) / 4
		}
		b := buf[:n*4]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = math.Float32frombits(binary.LittleEndian.Uint32(b[j*4 : j*4+4]))
		}
		i += n
	}
	return nil
}

func decodeF16(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf)/2 {
			n = len(buf) / 2
		}
		b := buf[:n*2]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = float16ToFloat32(binary.LittleEndian.Uint16(b[j*2 : j*2+2]))
		}
		i += n
	}
	return nil
}

func decodeBF16(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf)/2 {
			n = len(buf) / 2
		}
		b := buf[:n*2]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = bfloat16ToFloat32(binary.LittleEndian.Uint16(b[j*2 : j*2+2]))
		}
		i += n
	}
	return nil
}

func decodeF8E4M3(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf) {
			n = len(buf)
		}
		b := buf[:n]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = float8E4M3ToFloat32(b[j])
		}
		i += n
	}
	return nil
}

func decodeF8E5M2(r io.Reader, dst []float32) error {
	buf := stDecodeBufPool.Get().([]byte)
	defer stDecodeBufPool.Put(buf)
	i := 0
	for i < len(dst) {
		remaining := len(dst) - i
		n := remaining
		if n > len(buf) {
			n = len(buf)
		}
		b := buf[:n]
		if _, err := io.ReadFull(r, b); err != nil {
			return err
		}
		for j := 0; j < n; j++ {
			dst[i+j] = float8E5M2ToFloat32(b[j])
		}
		i += n
	}
	return nil
}
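// toFloat32 converts an in-memory byte slice to float32 in one pass; it is the
// non-streaming counterpart of readTensorFloat32FromFile and supports the same
// dtypes.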
func toFloat32(data []byte, dtype string) ([]float32, error) {
	dt := strings.ToUpper(strings.TrimSpace(dtype))
	switch dt {
	case "F32":
		n := len(data) / 4
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			bits := binary.LittleEndian.Uint32(data[i*4 : i*4+4])
			result[i] = math.Float32frombits(bits)
		}
		return result, nil
	case "F16":
		n := len(data) / 2
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			bits := binary.LittleEndian.Uint16(data[i*2 : i*2+2])
			result[i] = float16ToFloat32(bits)
		}
		return result, nil
	case "BF16":
		n := len(data) / 2
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			bits := binary.LittleEndian.Uint16(data[i*2 : i*2+2])
			result[i] = bfloat16ToFloat32(bits)
		}
		return result, nil
	// Float8 formats (common in new checkpoints)
	case "F8_E4M3", "F8_E4M3FN", "FLOAT8_E4M3FN":
		n := len(data)
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			result[i] = float8E4M3ToFloat32(data[i])
		}
		return result, nil
	case "F8_E5M2", "F8_E5M2FN", "FLOAT8_E5M2":
		n := len(data)
		result := make([]float32, n)
		for i := 0; i < n; i++ {
			result[i] = float8E5M2ToFloat32(data[i])
		}
		return result, nil
	default:
		return nil, fmt.Errorf("unsupported dtype: %s", dtype)
	}
}
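// float8E4M3ToFloat32 decodes a 1-4-3 (sign-exponent-mantissa) float8 value
// with exponent bias 7. The FN variant used by ML checkpoints has no
// infinities: only the all-ones exponent-and-mantissa bit pattern encodes NaN.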
func float8E4M3ToFloat32(b byte) float32 {
	sign := (b >> 7) & 0x1
	exp := (b >> 3) & 0xF
	mant := b & 0x7
	if exp == 0 {
		if mant == 0 {
			// +/- 0
			if sign == 1 {
				return float32(math.Copysign(0, -1))
			}
			return 0
		}
		// subnormal: (mant / 2^3) * 2^(1-bias)
		v := float32(mant) / 8.0
		v = float32(math.Ldexp(float64(v), 1-7))
		if sign == 1 {
			v = -v
		}
		return v
	}
	// E4M3FN: only exponent 0b1111 with mantissa 0b111 is NaN; the remaining
	// exponent-0b1111 codes are ordinary finite values (up to +/-448).
	if exp == 0xF && mant == 0x7 {
		return float32(math.NaN())
	}
	// normal: (1 + mant/2^3) * 2^(exp-bias)
	frac := 1.0 + float64(mant)/8.0
	v := math.Ldexp(frac, int(exp)-7)
	if sign == 1 {
		v = -v
	}
	return float32(v)
}
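// float8E5M2ToFloat32 decodes a 1-5-2 float8 value with exponent bias 15;
// unlike E4M3FN it follows IEEE conventions, so an all-ones exponent encodes
// infinity (zero mantissa) or NaN.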
func float8E5M2ToFloat32(b byte) float32 {
	sign := (b >> 7) & 0x1
	exp := (b >> 2) & 0x1F
	mant := b & 0x3
	if exp == 0 {
		if mant == 0 {
			if sign == 1 {
				return float32(math.Copysign(0, -1))
			}
			return 0
		}
		// subnormal: (mant / 2^2) * 2^(1-bias)
		v := float32(mant) / 4.0
		v = float32(math.Ldexp(float64(v), 1-15))
		if sign == 1 {
			v = -v
		}
		return v
	}
	if exp == 0x1F {
		if mant == 0 {
			if sign == 1 {
				return float32(math.Inf(-1))
			}
			return float32(math.Inf(1))
		}
		return float32(math.NaN())
	}
	frac := 1.0 + float64(mant)/4.0
	v := math.Ldexp(frac, int(exp)-15)
	if sign == 1 {
		v = -v
	}
	return float32(v)
}
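// float16ToFloat32 widens an IEEE 754 half-precision value bit-exactly,
// normalizing subnormals and preserving infinities and NaN.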
func float16ToFloat32(bits uint16) float32 {
	sign := uint32(bits&0x8000) << 16
	exp := uint32(bits&0x7C00) >> 10
	mant := uint32(bits & 0x03FF)
	if exp == 0 {
		if mant == 0 {
			return math.Float32frombits(sign)
		}
		for mant&0x0400 == 0 {
			mant <<= 1
			exp--
		}
		exp++
		mant &= 0x03FF
	} else if exp == 0x1F {
		if mant == 0 {
			return math.Float32frombits(sign | 0x7F800000)
		}
		return math.Float32frombits(sign | 0x7FC00000)
	}
	exp += 127 - 15
	return math.Float32frombits(sign | (exp << 23) | (mant << 13))
}
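// bfloat16ToFloat32 widens a bfloat16 value by placing its 16 bits in the high
// half of a float32; both formats share the same exponent width, so no
// rebiasing is needed.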
func bfloat16ToFloat32(bits uint16) float32 {
	return math.Float32frombits(uint32(bits) << 16)
}
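// float32ToBytes serializes a float32 slice to little-endian bytes.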
func float32ToBytes(floats []float32) []byte {
	data := make([]byte, len(floats)*4)
	for i, f := range floats {
		bits := math.Float32bits(f)
		binary.LittleEndian.PutUint32(data[i*4:], bits)
	}
	return data
}