| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- package loader
- import (
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "os"
- "syscall"
- )
// ModelData represents a loaded model file: the decoded header, section table
// and JSON metadata, plus the handles used to fetch section/tensor bytes on
// demand. Release it with Close.
type ModelData struct {
	// Header is the fixed-size (HeaderSize bytes) header at offset 0.
	Header Header
	// Sections is the section table stored immediately after the header.
	Sections *SectionTable
	// Metadata is the metadata section, decoded with encoding/json.
	Metadata Metadata
	// TensorIndex is searched by GetTensorData for names missing from
	// Metadata.Tensors. NOTE(review): parse does not fill it in the code
	// visible here — confirm where it is populated.
	TensorIndex []TensorEntry

	// data is the read-only mapping of the whole file when loaded with
	// UseMmap; it is nil when bytes are read from file with ReadAt instead.
	data []byte
	// file is the open model file, kept open in both modes and closed by Close.
	file *os.File
	// size is the file length in bytes; every read is bounds-checked against it.
	size uint64
}
// LoadOptions configures how LoadWithOptions accesses the model file.
// The zero value does not memory-map the file.
type LoadOptions struct {
	// UseMmap controls whether the model file is memory-mapped.
	// When false, tensors are read from the file into heap allocations (RAM).
	UseMmap bool
}
- // Load opens a .mak file and mmaps it (legacy behavior).
- func Load(path string) (*ModelData, error) {
- return LoadWithOptions(path, LoadOptions{UseMmap: true})
- }
- // LoadWithOptions opens a .mak file with configurable I/O strategy.
- func LoadWithOptions(path string, opts LoadOptions) (*ModelData, error) {
- f, err := os.Open(path)
- if err != nil {
- return nil, fmt.Errorf("failed to open file: %w", err)
- }
- fi, err := f.Stat()
- if err != nil {
- f.Close()
- return nil, fmt.Errorf("failed to stat file: %w", err)
- }
- size := fi.Size()
- md := &ModelData{
- file: f,
- size: uint64(size),
- }
- if opts.UseMmap {
- // Mmap the file
- data, err := syscall.Mmap(int(f.Fd()), 0, int(size), syscall.PROT_READ, syscall.MAP_SHARED)
- if err != nil {
- f.Close()
- return nil, fmt.Errorf("failed to mmap file: %w", err)
- }
- md.data = data
- }
- if err := md.parse(); err != nil {
- md.Close()
- return nil, err
- }
- return md, nil
- }
- func (md *ModelData) Close() error {
- if md.data != nil {
- syscall.Munmap(md.data)
- }
- if md.file != nil {
- return md.file.Close()
- }
- return nil
- }
- func (md *ModelData) parse() error {
- if md.data != nil {
- if len(md.data) < HeaderSize {
- return fmt.Errorf("file too small")
- }
- // Read header
- reader := bytes.NewReader(md.data[:HeaderSize])
- h, err := ReadHeader(reader)
- if err != nil {
- return err
- }
- md.Header = *h
- md.size = uint64(len(md.data))
- // Read section table
- sectionTableStart := HeaderSize
- sectionTableEnd := sectionTableStart + int(h.SectionCount)*SectionEntrySize
- if sectionTableEnd > len(md.data) {
- return fmt.Errorf("section table out of bounds")
- }
- sectionReader := bytes.NewReader(md.data[sectionTableStart:sectionTableEnd])
- sections, err := ReadSectionTable(sectionReader, h.SectionCount)
- if err != nil {
- return err
- }
- md.Sections = sections
- // Parse metadata section
- metaSection, err := sections.FindSection(SectionMetadata)
- if err != nil {
- return fmt.Errorf("metadata section not found: %w", err)
- }
- metaBytes := md.data[metaSection.Offset : metaSection.Offset+metaSection.Size]
- if err := json.Unmarshal(metaBytes, &md.Metadata); err != nil {
- return fmt.Errorf("failed to unmarshal metadata: %w", err)
- }
- return nil
- }
- if md.file == nil {
- return fmt.Errorf("no file handle")
- }
- if md.size < HeaderSize {
- return fmt.Errorf("file too small")
- }
- readExactAt := func(off uint64, dst []byte) error {
- if off+uint64(len(dst)) > md.size {
- return fmt.Errorf("read out of bounds: off=%d size=%d file=%d", off, len(dst), md.size)
- }
- n, err := md.file.ReadAt(dst, int64(off))
- if n != len(dst) {
- if err == nil {
- err = io.ErrUnexpectedEOF
- }
- return err
- }
- return nil
- }
- // Read header
- hdr := make([]byte, HeaderSize)
- if err := readExactAt(0, hdr); err != nil {
- return fmt.Errorf("read header: %w", err)
- }
- h, err := ReadHeader(bytes.NewReader(hdr))
- if err != nil {
- return err
- }
- md.Header = *h
- // Read section table
- stSize := uint64(h.SectionCount) * uint64(SectionEntrySize)
- st := make([]byte, stSize)
- if err := readExactAt(uint64(HeaderSize), st); err != nil {
- return fmt.Errorf("read section table: %w", err)
- }
- sections, err := ReadSectionTable(bytes.NewReader(st), h.SectionCount)
- if err != nil {
- return err
- }
- md.Sections = sections
- // Parse metadata section
- metaSection, err := sections.FindSection(SectionMetadata)
- if err != nil {
- return fmt.Errorf("metadata section not found: %w", err)
- }
- if metaSection.Offset+metaSection.Size > md.size {
- return fmt.Errorf("metadata section out of bounds")
- }
- metaBytes := make([]byte, metaSection.Size)
- if err := readExactAt(metaSection.Offset, metaBytes); err != nil {
- return fmt.Errorf("read metadata section: %w", err)
- }
- if err := json.Unmarshal(metaBytes, &md.Metadata); err != nil {
- return fmt.Errorf("failed to unmarshal metadata: %w", err)
- }
- return nil
- }
- // GetTensorData returns the byte slice for a named tensor
- func (md *ModelData) GetTensorData(name string) ([]byte, error) {
- // Find tensor data section for offset calculation
- dataSection, err := md.Sections.FindSection(SectionTensorData)
- if err != nil {
- return nil, fmt.Errorf("tensor data section not found: %w", err)
- }
-
- // Check tensors in metadata
- if info, ok := md.Metadata.Tensors[name]; ok {
- // info.Offset is relative to tensor data section
- absOffset := dataSection.Offset + info.Offset
- if absOffset+info.Size > md.size {
- return nil, fmt.Errorf("tensor %s data out of bounds", name)
- }
- if md.data != nil {
- return md.data[absOffset : absOffset+info.Size], nil
- }
- if info.Size > uint64(int(^uint(0)>>1)) {
- return nil, fmt.Errorf("tensor %s too large: %d bytes", name, info.Size)
- }
- buf := make([]byte, info.Size)
- n, err := md.file.ReadAt(buf, int64(absOffset))
- if n != len(buf) {
- if err == nil {
- err = io.ErrUnexpectedEOF
- }
- return nil, fmt.Errorf("read tensor %s: %w", name, err)
- }
- return buf, nil
- }
-
- // Check tensor index
- for _, entry := range md.TensorIndex {
- if entry.Name == name {
- // Find tensor data section
- dataSection, err := md.Sections.FindSection(SectionTensorData)
- if err != nil {
- return nil, err
- }
-
- absOffset := dataSection.Offset + entry.Offset
- if absOffset+entry.Size > md.size {
- return nil, fmt.Errorf("tensor %s data out of bounds", name)
- }
- if md.data != nil {
- return md.data[absOffset : absOffset+entry.Size], nil
- }
- if entry.Size > uint64(int(^uint(0)>>1)) {
- return nil, fmt.Errorf("tensor %s too large: %d bytes", name, entry.Size)
- }
- buf := make([]byte, entry.Size)
- n, err := md.file.ReadAt(buf, int64(absOffset))
- if n != len(buf) {
- if err == nil {
- err = io.ErrUnexpectedEOF
- }
- return nil, fmt.Errorf("read tensor %s: %w", name, err)
- }
- return buf, nil
- }
- }
-
- return nil, fmt.Errorf("tensor not found: %s", name)
- }
- // GetTokenizerData returns the embedded tokenizer JSON data if present
- func (md *ModelData) GetTokenizerData() ([]byte, error) {
- section, err := md.Sections.FindSection(SectionTokenizer)
- if err != nil {
- return nil, err // Section not found
- }
-
- if section.Offset+section.Size > md.size {
- return nil, fmt.Errorf("tokenizer data out of bounds")
- }
- if md.data != nil {
- return md.data[section.Offset : section.Offset+section.Size], nil
- }
- if section.Size > uint64(int(^uint(0)>>1)) {
- return nil, fmt.Errorf("tokenizer section too large: %d bytes", section.Size)
- }
- buf := make([]byte, section.Size)
- n, err := md.file.ReadAt(buf, int64(section.Offset))
- if n != len(buf) {
- if err == nil {
- err = io.ErrUnexpectedEOF
- }
- return nil, fmt.Errorf("read tokenizer: %w", err)
- }
- return buf, nil
- }
|