| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275 |
- package loader
- import (
- "encoding/json"
- "fmt"
- "io"
- "os"
- )
// Writer helps create .mak files with section-based format
type Writer struct {
	file          *os.File      // underlying destination file handle
	header        Header        // fixed-size file header; finalized and written in Close
	metadata      Metadata      // metadata section payload (model config + tensor index), JSON-encoded in Close
	tensors       []TensorEntry // tensor entries in the order they were added
	tensorDataPos uint64        // absolute file offset of tensor data section start
	tensorOffset  uint64        // relative offset within tensor data section
	tokenizerData []byte        // optional tokenizer payload; written as its own section when non-empty
	closed        bool          // set by Close; guards against double-finalization and late writes
}
- // NewWriter creates a new MAK file writer
- func NewWriter(path string) (*Writer, error) {
- f, err := os.Create(path)
- if err != nil {
- return nil, err
- }
- w := &Writer{
- file: f,
- header: Header{
- Magic: [4]byte{'M', 'A', 'K', 'A'},
- Version: Version,
- Flags: FlagLittleEndian,
- Alignment: Alignment,
- },
- metadata: Metadata{
- Tensors: make(map[string]TensorEntry),
- },
- tensors: make([]TensorEntry, 0),
- }
- // Reserve space for header + section table. We keep this fixed so we can stream
- // tensor data immediately, then seek back and write the final header/table.
- //
- // Max sections: metadata + tensor data + tokenizer.
- const maxSections = uint32(3)
- prefixSize := uint64(HeaderSize + int(maxSections)*SectionEntrySize)
- w.tensorDataPos = alignOffset(prefixSize, uint64(Alignment))
- if _, err := f.Seek(int64(w.tensorDataPos), io.SeekStart); err != nil {
- _ = f.Close()
- return nil, fmt.Errorf("seek tensor data start: %w", err)
- }
- return w, nil
- }
// SetModelConfig sets the model configuration stored in the metadata section.
// It must be called before Close for the configuration to be written out.
func (w *Writer) SetModelConfig(cfg ModelConfig) {
	w.metadata.ModelConfig = cfg
}
- // AddTokenizer adds tokenizer data to the file
- func (w *Writer) AddTokenizer(data []byte) {
- w.tokenizerData = data
- }
- // AddTensor adds a tensor to the file
- func (w *Writer) AddTensor(name string, dtype DType, shape []uint64, data []byte) error {
- if w.closed || w.file == nil {
- return fmt.Errorf("writer is closed")
- }
- // Align to alignment boundary
- padding := (uint64(Alignment) - (w.tensorOffset % uint64(Alignment))) % uint64(Alignment)
- if padding > 0 {
- if err := writeZeros(w.file, padding); err != nil {
- return fmt.Errorf("write tensor padding: %w", err)
- }
- w.tensorOffset += padding
- }
- entry := TensorEntry{
- Name: name,
- DType: dtype,
- Shape: shape,
- Offset: w.tensorOffset,
- Size: uint64(len(data)),
- }
- w.tensors = append(w.tensors, entry)
- w.metadata.Tensors[name] = entry
- if err := writeAll(w.file, data); err != nil {
- return fmt.Errorf("write tensor %s: %w", name, err)
- }
- w.tensorOffset += entry.Size
- return nil
- }
- // AddTensorFromReader streams tensor data from r into the file without keeping it in memory.
- func (w *Writer) AddTensorFromReader(name string, dtype DType, shape []uint64, r io.Reader, size uint64) error {
- if w.closed || w.file == nil {
- return fmt.Errorf("writer is closed")
- }
- // Align to alignment boundary
- padding := (uint64(Alignment) - (w.tensorOffset % uint64(Alignment))) % uint64(Alignment)
- if padding > 0 {
- if err := writeZeros(w.file, padding); err != nil {
- return fmt.Errorf("write tensor padding: %w", err)
- }
- w.tensorOffset += padding
- }
- entry := TensorEntry{
- Name: name,
- DType: dtype,
- Shape: shape,
- Offset: w.tensorOffset,
- Size: size,
- }
- w.tensors = append(w.tensors, entry)
- w.metadata.Tensors[name] = entry
- if err := copyExactly(w.file, r, int64(size)); err != nil {
- return fmt.Errorf("write tensor %s: %w", name, err)
- }
- w.tensorOffset += size
- return nil
- }
- // Close finalizes and closes the file
- func (w *Writer) Close() error {
- if w.closed {
- return nil
- }
- w.closed = true
- if w.file == nil {
- return nil
- }
- // Build sections
- // Section 1: Metadata JSON
- metaJSON, err := json.Marshal(w.metadata)
- if err != nil {
- _ = w.file.Close()
- return fmt.Errorf("marshal metadata: %w", err)
- }
- // Calculate offsets
- numSections := uint32(2) // TensorData + Metadata
- if len(w.tokenizerData) > 0 {
- numSections = 3 // + Tokenizer
- }
- tensorDataOffset := w.tensorDataPos
- tensorDataSize := w.tensorOffset
- metaOffset := alignOffset(tensorDataOffset+tensorDataSize, uint64(Alignment))
- metaSize := uint64(len(metaJSON))
- // Build section table
- sections := &SectionTable{
- Entries: []SectionEntry{
- {Type: SectionTensorData, Offset: tensorDataOffset, Size: tensorDataSize, Flags: 0},
- {Type: SectionMetadata, Offset: metaOffset, Size: metaSize, Flags: 0},
- },
- }
- // Write metadata after tensor data (streaming-friendly layout).
- curPos := tensorDataOffset + tensorDataSize
- if metaOffset < curPos {
- _ = w.file.Close()
- return fmt.Errorf("invalid metadata offset")
- }
- if err := writeZeros(w.file, metaOffset-curPos); err != nil {
- _ = w.file.Close()
- return fmt.Errorf("write metadata padding: %w", err)
- }
- if err := writeAll(w.file, metaJSON); err != nil {
- _ = w.file.Close()
- return fmt.Errorf("write metadata: %w", err)
- }
- // Add tokenizer section if present
- var tokenizerOffset uint64
- if len(w.tokenizerData) > 0 {
- tokenizerOffset = alignOffset(metaOffset+metaSize, uint64(Alignment))
- sections.Entries = append(sections.Entries, SectionEntry{
- Type: SectionTokenizer,
- Offset: tokenizerOffset,
- Size: uint64(len(w.tokenizerData)),
- Flags: 0,
- })
- curPos = metaOffset + metaSize
- if tokenizerOffset < curPos {
- _ = w.file.Close()
- return fmt.Errorf("invalid tokenizer offset")
- }
- if err := writeZeros(w.file, tokenizerOffset-curPos); err != nil {
- _ = w.file.Close()
- return fmt.Errorf("write tokenizer padding: %w", err)
- }
- if err := writeAll(w.file, w.tokenizerData); err != nil {
- _ = w.file.Close()
- return fmt.Errorf("write tokenizer: %w", err)
- }
- }
- // Update header
- w.header.SectionCount = numSections
- // Seek back and write header + section table.
- if _, err := w.file.Seek(0, io.SeekStart); err != nil {
- _ = w.file.Close()
- return fmt.Errorf("seek header: %w", err)
- }
- if err := WriteHeader(w.file, &w.header); err != nil {
- _ = w.file.Close()
- return fmt.Errorf("write header: %w", err)
- }
- if err := WriteSectionTable(w.file, sections); err != nil {
- _ = w.file.Close()
- return fmt.Errorf("write section table: %w", err)
- }
- return w.file.Close()
- }
// alignOffset rounds offset up to the next multiple of alignment; an
// already-aligned offset is returned unchanged. alignment must be non-zero.
func alignOffset(offset, alignment uint64) uint64 {
	if rem := offset % alignment; rem != 0 {
		return offset + alignment - rem
	}
	return offset
}
- var zeroBuf = make([]byte, Alignment)
- func writeZeros(w io.Writer, n uint64) error {
- for n > 0 {
- chunk := n
- if chunk > uint64(len(zeroBuf)) {
- chunk = uint64(len(zeroBuf))
- }
- if err := writeAll(w, zeroBuf[:chunk]); err != nil {
- return err
- }
- n -= chunk
- }
- return nil
- }
- func writeAll(w io.Writer, b []byte) error {
- for len(b) > 0 {
- n, err := w.Write(b)
- if err != nil {
- return err
- }
- b = b[n:]
- }
- return nil
- }
- func copyExactly(dst io.Writer, src io.Reader, n int64) error {
- written, err := io.CopyN(dst, src, n)
- if err != nil {
- return err
- }
- if written != n {
- return fmt.Errorf("short write: %d != %d", written, n)
- }
- return nil
- }
|