package cpu

import (
	"fmt"
	"unsafe"

	"makarna/pkg/tensor"
)

// Tensor implements tensor.Tensor for CPU. Exactly one of the data*
// fields is populated, selected by dtype.
type Tensor struct {
	shape       tensor.Shape
	dtype       tensor.DType
	dataFloat32 []float32
	dataUint16  []uint16
	dataQ4_K    []tensor.BlockQ4_K
	dataQ3_K    []tensor.BlockQ3_K
	dataQ5_K    []tensor.BlockQ5_K
	dataQ6_K    []tensor.BlockQ6_K
	dataQ8_K    []tensor.BlockQ8_K
	dataQ2_K    []tensor.BlockQ2_K
}

// Placement reports that the tensor lives on the CPU (no GPU index).
func (t *Tensor) Placement() tensor.DevicePlacement {
	return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
}

// NewTensor creates a new Float32 tensor. If data is nil, a zeroed
// backing slice of shape.NumElements() elements is allocated.
func NewTensor(shape tensor.Shape, data []float32) *Tensor {
	if data == nil {
		data = make([]float32, shape.NumElements())
	}
	return &Tensor{
		shape:       shape,
		dtype:       tensor.Float32,
		dataFloat32: data,
	}
}

// NewTensorU16 creates a Float16 or BFloat16 tensor backed by raw
// uint16 words. If data is nil, a zeroed slice is allocated. Returns an
// error for any other dtype or if len(data) does not match the shape.
func NewTensorU16(shape tensor.Shape, dtype tensor.DType, data []uint16) (*Tensor, error) {
	if dtype != tensor.Float16 && dtype != tensor.BFloat16 {
		return nil, fmt.Errorf("unsupported u16 tensor dtype: %v", dtype)
	}
	if data == nil {
		data = make([]uint16, shape.NumElements())
	}
	if len(data) != shape.NumElements() {
		return nil, fmt.Errorf("size mismatch for u16 tensor: expected %d, got %d", shape.NumElements(), len(data))
	}
	return &Tensor{shape: shape, dtype: dtype, dataUint16: data}, nil
}

// qkBlockSize is the number of elements per K-quant super-block; all
// K-quant formats handled here use 256-element blocks.
const qkBlockSize = 256

// quantBlockCount validates a K-quant tensor's element count (must be a
// whole number of blocks) and byte length (must match numBlocks *
// blockBytes), returning the block count. name is used in error text.
func quantBlockCount(name string, elemCount, dataLen, blockBytes int) (int, error) {
	if elemCount%qkBlockSize != 0 {
		return 0, fmt.Errorf("%s tensor elements must be multiple of %d", name, qkBlockSize)
	}
	numBlocks := elemCount / qkBlockSize
	expectedSize := numBlocks * blockBytes
	if dataLen != expectedSize {
		return 0, fmt.Errorf("size mismatch for %s tensor: expected %d bytes, got %d", name, expectedSize, dataLen)
	}
	return numBlocks, nil
}

// NewTensorFromBytes creates a tensor that aliases raw bytes without
// copying (intended for mmap'd weights). The caller must keep data
// alive for the lifetime of the tensor, and data must be suitably
// aligned for the target element/block type — the unsafe.Slice casts
// below assume this.
func NewTensorFromBytes(shape tensor.Shape, dtype tensor.DType, data []byte) (*Tensor, error) {
	if len(data) == 0 {
		return nil, fmt.Errorf("empty data for tensor")
	}
	ptr := unsafe.Pointer(&data[0])
	t := &Tensor{shape: shape, dtype: dtype}
	elemCount := shape.NumElements()

	switch dtype {
	case tensor.Float32:
		if expected := elemCount * 4; len(data) != expected {
			return nil, fmt.Errorf("size mismatch for F32 tensor: expected %d bytes, got %d", expected, len(data))
		}
		// Zero-copy cast
		t.dataFloat32 = unsafe.Slice((*float32)(ptr), elemCount)

	case tensor.Float16, tensor.BFloat16:
		if expected := elemCount * 2; len(data) != expected {
			return nil, fmt.Errorf("size mismatch for %v tensor: expected %d bytes, got %d", dtype, expected, len(data))
		}
		t.dataUint16 = unsafe.Slice((*uint16)(ptr), elemCount)

	case tensor.Q2_K:
		// 84 bytes/block: 16(scales)+64(qs)+2(D)+2(DMin)
		n, err := quantBlockCount("Q2_K", elemCount, len(data), 84)
		if err != nil {
			return nil, err
		}
		t.dataQ2_K = unsafe.Slice((*tensor.BlockQ2_K)(ptr), n)

	case tensor.Q3_K:
		// 110 bytes/block: 32(hmask)+64(qs)+12(scales)+2(D)
		n, err := quantBlockCount("Q3_K", elemCount, len(data), 110)
		if err != nil {
			return nil, err
		}
		t.dataQ3_K = unsafe.Slice((*tensor.BlockQ3_K)(ptr), n)

	case tensor.Q4_K:
		// 144 bytes/block: 2(D)+2(DMin)+12(scales)+128(qs)
		n, err := quantBlockCount("Q4_K", elemCount, len(data), 144)
		if err != nil {
			return nil, err
		}
		t.dataQ4_K = unsafe.Slice((*tensor.BlockQ4_K)(ptr), n)

	case tensor.Q5_K:
		// 176 bytes/block
		n, err := quantBlockCount("Q5_K", elemCount, len(data), 176)
		if err != nil {
			return nil, err
		}
		t.dataQ5_K = unsafe.Slice((*tensor.BlockQ5_K)(ptr), n)

	case tensor.Q6_K:
		// 210 bytes/block: 128(ql)+64(qh)+16(scales)+2(D)
		n, err := quantBlockCount("Q6_K", elemCount, len(data), 210)
		if err != nil {
			return nil, err
		}
		t.dataQ6_K = unsafe.Slice((*tensor.BlockQ6_K)(ptr), n)

	case tensor.Q8_K:
		// 292 bytes/block: 4(D)+256(qs)+32(bsums)
		n, err := quantBlockCount("Q8_K", elemCount, len(data), 292)
		if err != nil {
			return nil, err
		}
		t.dataQ8_K = unsafe.Slice((*tensor.BlockQ8_K)(ptr), n)

	default:
		return nil, fmt.Errorf("unsupported tensor dtype: %v", dtype)
	}
	return t, nil
}

// Shape returns tensor dimensions.
func (t *Tensor) Shape() tensor.Shape { return t.shape }

// DType returns data type.
func (t *Tensor) DType() tensor.DType { return t.dtype }

// Device returns CPU.
func (t *Tensor) Device() tensor.DeviceType { return tensor.CPU }

// Data returns raw data pointer.
// Use DataFloat32/DataQ4_K etc for type-safe access.
func (t *Tensor) Data() interface{} { switch t.dtype { case tensor.Float32: return unsafe.Pointer(&t.dataFloat32[0]) case tensor.Float16, tensor.BFloat16: return unsafe.Pointer(&t.dataUint16[0]) case tensor.Q4_K: return unsafe.Pointer(&t.dataQ4_K[0]) case tensor.Q3_K: return unsafe.Pointer(&t.dataQ3_K[0]) case tensor.Q5_K: return unsafe.Pointer(&t.dataQ5_K[0]) case tensor.Q6_K: return unsafe.Pointer(&t.dataQ6_K[0]) case tensor.Q8_K: return unsafe.Pointer(&t.dataQ8_K[0]) case tensor.Q2_K: return unsafe.Pointer(&t.dataQ2_K[0]) default: panic(fmt.Sprintf("internal error: unsupported dtype %v in Data()", t.dtype)) } } // DataFloat32 returns the underlying float32 slice directly func (t *Tensor) DataFloat32() []float32 { return t.dataFloat32 } func (t *Tensor) DataUint16() []uint16 { return t.dataUint16 } // DataQ4_K returns the underlying Q4_K block slice func (t *Tensor) DataQ4_K() []tensor.BlockQ4_K { return t.dataQ4_K } // DataQ3_K returns the underlying Q3_K block slice func (t *Tensor) DataQ3_K() []tensor.BlockQ3_K { return t.dataQ3_K } func (t *Tensor) DataQ5_K() []tensor.BlockQ5_K { return t.dataQ5_K } // DataQ6_K returns the underlying Q6_K block slice func (t *Tensor) DataQ6_K() []tensor.BlockQ6_K { return t.dataQ6_K } // DataQ8_K returns the underlying Q8_K block slice func (t *Tensor) DataQ8_K() []tensor.BlockQ8_K { return t.dataQ8_K } // DataQ2_K returns the underlying Q2_K block slice func (t *Tensor) DataQ2_K() []tensor.BlockQ2_K { return t.dataQ2_K }