tensor.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. package tensor
  2. import "fmt"
  3. // DType represents the data type of a tensor
  4. type DType int
  5. const (
  6. Float32 DType = 0
  7. Float16 DType = 1
  8. BFloat16 DType = 2
  9. Int8 DType = 3 // Legacy/Internal
  10. Int4 DType = 4 // Legacy/Internal
  11. Int32 DType = 5
  12. Q4_K DType = 22
  13. Q3_K DType = 23
  14. Q5_K DType = 24
  15. Q6_K DType = 26
  16. Q8_K DType = 27
  17. Q2_K DType = 28
  18. )
  19. func (d DType) String() string {
  20. switch d {
  21. case Float32:
  22. return "Float32"
  23. case Float16:
  24. return "Float16"
  25. case BFloat16:
  26. return "BFloat16"
  27. case Int8:
  28. return "Int8"
  29. case Int4:
  30. return "Int4"
  31. case Int32:
  32. return "Int32"
  33. case Q4_K:
  34. return "Q4_K"
  35. case Q3_K:
  36. return "Q3_K"
  37. case Q5_K:
  38. return "Q5_K"
  39. case Q6_K:
  40. return "Q6_K"
  41. case Q8_K:
  42. return "Q8_K"
  43. case Q2_K:
  44. return "Q2_K"
  45. default:
  46. return fmt.Sprintf("DType(%d)", d)
  47. }
  48. }
  49. func (d DType) Size() int {
  50. switch d {
  51. case Float32:
  52. return 4
  53. case Float16:
  54. return 2
  55. case BFloat16:
  56. return 2
  57. case Int8:
  58. return 1
  59. case Int4:
  60. return 0 // bitpacked
  61. case Q4_K:
  62. return 0 // block based
  63. case Q3_K:
  64. return 0 // block based
  65. case Q5_K:
  66. return 0 // block based
  67. case Q6_K:
  68. return 0 // block based
  69. case Q8_K:
  70. return 0 // block based
  71. case Q2_K:
  72. return 0 // block based
  73. default:
  74. panic("unknown dtype")
  75. }
  76. }
  77. // Shape represents tensor dimensions
  78. type Shape []int
  79. func (s Shape) NumElements() int {
  80. if len(s) == 0 {
  81. return 0
  82. }
  83. n := 1
  84. for _, d := range s {
  85. n *= d
  86. }
  87. return n
  88. }
  89. func (s Shape) String() string {
  90. return fmt.Sprintf("%v", []int(s))
  91. }
  92. // DeviceType represents where tensor data lives
  93. type DeviceType int
  94. const (
  95. CPU DeviceType = iota
  96. CUDA
  97. )
  98. // DevicePlacement captures a target device and GPU ordinal (for CUDA).
  99. // GPU is ignored for CPU placements.
  100. type DevicePlacement struct {
  101. Type DeviceType
  102. GPU int
  103. }
  104. // Normalize ensures a valid placement with sane defaults.
  105. func (p DevicePlacement) Normalize() DevicePlacement {
  106. if p.Type != CUDA {
  107. return DevicePlacement{Type: CPU, GPU: -1}
  108. }
  109. if p.GPU < 0 {
  110. return DevicePlacement{Type: CUDA, GPU: 0}
  111. }
  112. return p
  113. }
// TensorWithPlacement adds placement to the tensor interface.
// Implementations additionally report the device (and, for CUDA,
// the GPU ordinal) on which their data resides.
type TensorWithPlacement interface {
Tensor
// Placement returns the device placement of the tensor's data.
Placement() DevicePlacement
}
// Tensor is a minimal core interface
// Operations are handled by standalone functions in ops/, matmul/, nn/
type Tensor interface {
// Shape returns the tensor's dimensions.
Shape() Shape
// DType returns the element data type of the tensor.
DType() DType
// Device reports the device type where the data lives (CPU or CUDA).
Device() DeviceType
// Data exposes the raw backing data; the concrete dynamic type
// presumably depends on the implementation and dtype — confirm
// with the implementing types before asserting.
Data() interface{}
}