tensor.go 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. package cpu
  2. import (
  3. "fmt"
  4. "unsafe"
  5. "makarna/pkg/tensor"
  6. )
  7. // Tensor implements tensor.Tensor for CPU
  8. type Tensor struct {
  9. shape tensor.Shape
  10. dtype tensor.DType
  11. dataFloat32 []float32
  12. dataUint16 []uint16
  13. dataQ4_K []tensor.BlockQ4_K
  14. dataQ3_K []tensor.BlockQ3_K
  15. dataQ5_K []tensor.BlockQ5_K
  16. dataQ6_K []tensor.BlockQ6_K
  17. dataQ8_K []tensor.BlockQ8_K
  18. dataQ2_K []tensor.BlockQ2_K
  19. }
  20. func (t *Tensor) Placement() tensor.DevicePlacement {
  21. return tensor.DevicePlacement{Type: tensor.CPU, GPU: -1}
  22. }
  23. // NewTensor creates a new Float32 tensor
  24. func NewTensor(shape tensor.Shape, data []float32) *Tensor {
  25. if data == nil {
  26. data = make([]float32, shape.NumElements())
  27. }
  28. return &Tensor{
  29. shape: shape,
  30. dtype: tensor.Float32,
  31. dataFloat32: data,
  32. }
  33. }
  34. func NewTensorU16(shape tensor.Shape, dtype tensor.DType, data []uint16) (*Tensor, error) {
  35. if dtype != tensor.Float16 && dtype != tensor.BFloat16 {
  36. return nil, fmt.Errorf("unsupported u16 tensor dtype: %v", dtype)
  37. }
  38. if data == nil {
  39. data = make([]uint16, shape.NumElements())
  40. }
  41. if len(data) != shape.NumElements() {
  42. return nil, fmt.Errorf("size mismatch for u16 tensor: expected %d, got %d", shape.NumElements(), len(data))
  43. }
  44. return &Tensor{shape: shape, dtype: dtype, dataUint16: data}, nil
  45. }
  46. // NewTensorFromBytes creates a tensor from raw bytes (zero-copy mmap)
  47. func NewTensorFromBytes(shape tensor.Shape, dtype tensor.DType, data []byte) (*Tensor, error) {
  48. if len(data) == 0 {
  49. return nil, fmt.Errorf("empty data for tensor")
  50. }
  51. ptr := unsafe.Pointer(&data[0])
  52. t := &Tensor{
  53. shape: shape,
  54. dtype: dtype,
  55. }
  56. switch dtype {
  57. case tensor.Float32:
  58. expectedSize := shape.NumElements() * 4
  59. if len(data) != expectedSize {
  60. return nil, fmt.Errorf("size mismatch for F32 tensor: expected %d bytes, got %d", expectedSize, len(data))
  61. }
  62. // Zero-copy cast
  63. t.dataFloat32 = unsafe.Slice((*float32)(ptr), shape.NumElements())
  64. case tensor.Float16, tensor.BFloat16:
  65. expectedSize := shape.NumElements() * 2
  66. if len(data) != expectedSize {
  67. return nil, fmt.Errorf("size mismatch for %v tensor: expected %d bytes, got %d", dtype, expectedSize, len(data))
  68. }
  69. t.dataUint16 = unsafe.Slice((*uint16)(ptr), shape.NumElements())
  70. case tensor.Q4_K:
  71. const blockSize = 256
  72. const blockBytes = 144 // 2(D)+2(DMin)+12(scales)+128(qs)
  73. elemCount := shape.NumElements()
  74. if elemCount%blockSize != 0 {
  75. return nil, fmt.Errorf("Q4_K tensor elements must be multiple of %d", blockSize)
  76. }
  77. numBlocks := elemCount / blockSize
  78. expectedSize := numBlocks * blockBytes
  79. if len(data) != expectedSize {
  80. return nil, fmt.Errorf("size mismatch for Q4_K tensor: expected %d bytes, got %d", expectedSize, len(data))
  81. }
  82. t.dataQ4_K = unsafe.Slice((*tensor.BlockQ4_K)(ptr), numBlocks)
  83. case tensor.Q8_K:
  84. const blockSize = 256
  85. const blockBytes = 292 // 4(D)+256(qs)+32(bsums)
  86. elemCount := shape.NumElements()
  87. if elemCount%blockSize != 0 {
  88. return nil, fmt.Errorf("Q8_K tensor elements must be multiple of %d", blockSize)
  89. }
  90. numBlocks := elemCount / blockSize
  91. expectedSize := numBlocks * blockBytes
  92. if len(data) != expectedSize {
  93. return nil, fmt.Errorf("size mismatch for Q8_K tensor: expected %d bytes, got %d", expectedSize, len(data))
  94. }
  95. t.dataQ8_K = unsafe.Slice((*tensor.BlockQ8_K)(ptr), numBlocks)
  96. case tensor.Q3_K:
  97. const blockSize = 256
  98. const blockBytes = 110 // 32(hmask)+64(qs)+12(scales)+2(D)
  99. elemCount := shape.NumElements()
  100. if elemCount%blockSize != 0 {
  101. return nil, fmt.Errorf("Q3_K tensor elements must be multiple of %d", blockSize)
  102. }
  103. numBlocks := elemCount / blockSize
  104. expectedSize := numBlocks * blockBytes
  105. if len(data) != expectedSize {
  106. return nil, fmt.Errorf("size mismatch for Q3_K tensor: expected %d bytes, got %d", expectedSize, len(data))
  107. }
  108. t.dataQ3_K = unsafe.Slice((*tensor.BlockQ3_K)(ptr), numBlocks)
  109. case tensor.Q5_K:
  110. const blockSize = 256
  111. const blockBytes = 176
  112. elemCount := shape.NumElements()
  113. if elemCount%blockSize != 0 {
  114. return nil, fmt.Errorf("Q5_K tensor elements must be multiple of %d", blockSize)
  115. }
  116. numBlocks := elemCount / blockSize
  117. expectedSize := numBlocks * blockBytes
  118. if len(data) != expectedSize {
  119. return nil, fmt.Errorf("size mismatch for Q5_K tensor: expected %d bytes, got %d", expectedSize, len(data))
  120. }
  121. t.dataQ5_K = unsafe.Slice((*tensor.BlockQ5_K)(ptr), numBlocks)
  122. case tensor.Q6_K:
  123. const blockSize = 256
  124. const blockBytes = 210 // 128(ql)+64(qh)+16(scales)+2(D)
  125. elemCount := shape.NumElements()
  126. if elemCount%blockSize != 0 {
  127. return nil, fmt.Errorf("Q6_K tensor elements must be multiple of %d", blockSize)
  128. }
  129. numBlocks := elemCount / blockSize
  130. expectedSize := numBlocks * blockBytes
  131. if len(data) != expectedSize {
  132. return nil, fmt.Errorf("size mismatch for Q6_K tensor: expected %d bytes, got %d", expectedSize, len(data))
  133. }
  134. t.dataQ6_K = unsafe.Slice((*tensor.BlockQ6_K)(ptr), numBlocks)
  135. case tensor.Q2_K:
  136. const blockSize = 256
  137. const blockBytes = 84 // 16(scales)+64(qs)+2(D)+2(DMin)
  138. elemCount := shape.NumElements()
  139. if elemCount%blockSize != 0 {
  140. return nil, fmt.Errorf("Q2_K tensor elements must be multiple of %d", blockSize)
  141. }
  142. numBlocks := elemCount / blockSize
  143. expectedSize := numBlocks * blockBytes
  144. if len(data) != expectedSize {
  145. return nil, fmt.Errorf("size mismatch for Q2_K tensor: expected %d bytes, got %d", expectedSize, len(data))
  146. }
  147. t.dataQ2_K = unsafe.Slice((*tensor.BlockQ2_K)(ptr), numBlocks)
  148. default:
  149. return nil, fmt.Errorf("unsupported tensor dtype: %v", dtype)
  150. }
  151. return t, nil
  152. }
  153. // Shape returns tensor dimensions
  154. func (t *Tensor) Shape() tensor.Shape {
  155. return t.shape
  156. }
  157. // DType returns data type
  158. func (t *Tensor) DType() tensor.DType {
  159. return t.dtype
  160. }
  161. // Device returns CPU
  162. func (t *Tensor) Device() tensor.DeviceType {
  163. return tensor.CPU
  164. }
  165. // Data returns raw data pointer.
  166. // Use DataFloat32/DataQ4_K etc for type-safe access.
  167. func (t *Tensor) Data() interface{} {
  168. switch t.dtype {
  169. case tensor.Float32:
  170. return unsafe.Pointer(&t.dataFloat32[0])
  171. case tensor.Float16, tensor.BFloat16:
  172. return unsafe.Pointer(&t.dataUint16[0])
  173. case tensor.Q4_K:
  174. return unsafe.Pointer(&t.dataQ4_K[0])
  175. case tensor.Q3_K:
  176. return unsafe.Pointer(&t.dataQ3_K[0])
  177. case tensor.Q5_K:
  178. return unsafe.Pointer(&t.dataQ5_K[0])
  179. case tensor.Q6_K:
  180. return unsafe.Pointer(&t.dataQ6_K[0])
  181. case tensor.Q8_K:
  182. return unsafe.Pointer(&t.dataQ8_K[0])
  183. case tensor.Q2_K:
  184. return unsafe.Pointer(&t.dataQ2_K[0])
  185. default:
  186. panic(fmt.Sprintf("internal error: unsupported dtype %v in Data()", t.dtype))
  187. }
  188. }
  189. // DataFloat32 returns the underlying float32 slice directly
  190. func (t *Tensor) DataFloat32() []float32 {
  191. return t.dataFloat32
  192. }
  193. func (t *Tensor) DataUint16() []uint16 {
  194. return t.dataUint16
  195. }
  196. // DataQ4_K returns the underlying Q4_K block slice
  197. func (t *Tensor) DataQ4_K() []tensor.BlockQ4_K {
  198. return t.dataQ4_K
  199. }
  200. // DataQ3_K returns the underlying Q3_K block slice
  201. func (t *Tensor) DataQ3_K() []tensor.BlockQ3_K {
  202. return t.dataQ3_K
  203. }
  204. func (t *Tensor) DataQ5_K() []tensor.BlockQ5_K {
  205. return t.dataQ5_K
  206. }
  207. // DataQ6_K returns the underlying Q6_K block slice
  208. func (t *Tensor) DataQ6_K() []tensor.BlockQ6_K {
  209. return t.dataQ6_K
  210. }
  211. // DataQ8_K returns the underlying Q8_K block slice
  212. func (t *Tensor) DataQ8_K() []tensor.BlockQ8_K {
  213. return t.dataQ8_K
  214. }
  215. // DataQ2_K returns the underlying Q2_K block slice
  216. func (t *Tensor) DataQ2_K() []tensor.BlockQ2_K {
  217. return t.dataQ2_K
  218. }