package tensor import "fmt" // DType represents the data type of a tensor type DType int const ( Float32 DType = 0 Float16 DType = 1 BFloat16 DType = 2 Int8 DType = 3 // Legacy/Internal Int4 DType = 4 // Legacy/Internal Int32 DType = 5 Q4_K DType = 22 Q3_K DType = 23 Q5_K DType = 24 Q6_K DType = 26 Q8_K DType = 27 Q2_K DType = 28 ) func (d DType) String() string { switch d { case Float32: return "Float32" case Float16: return "Float16" case BFloat16: return "BFloat16" case Int8: return "Int8" case Int4: return "Int4" case Int32: return "Int32" case Q4_K: return "Q4_K" case Q3_K: return "Q3_K" case Q5_K: return "Q5_K" case Q6_K: return "Q6_K" case Q8_K: return "Q8_K" case Q2_K: return "Q2_K" default: return fmt.Sprintf("DType(%d)", d) } } func (d DType) Size() int { switch d { case Float32: return 4 case Float16: return 2 case BFloat16: return 2 case Int8: return 1 case Int4: return 0 // bitpacked case Q4_K: return 0 // block based case Q3_K: return 0 // block based case Q5_K: return 0 // block based case Q6_K: return 0 // block based case Q8_K: return 0 // block based case Q2_K: return 0 // block based default: panic("unknown dtype") } } // Shape represents tensor dimensions type Shape []int func (s Shape) NumElements() int { if len(s) == 0 { return 0 } n := 1 for _, d := range s { n *= d } return n } func (s Shape) String() string { return fmt.Sprintf("%v", []int(s)) } // DeviceType represents where tensor data lives type DeviceType int const ( CPU DeviceType = iota CUDA ) // DevicePlacement captures a target device and GPU ordinal (for CUDA). // GPU is ignored for CPU placements. type DevicePlacement struct { Type DeviceType GPU int } // Normalize ensures a valid placement with sane defaults. func (p DevicePlacement) Normalize() DevicePlacement { if p.Type != CUDA { return DevicePlacement{Type: CPU, GPU: -1} } if p.GPU < 0 { return DevicePlacement{Type: CUDA, GPU: 0} } return p } // TensorWithPlacement adds placement to the tensor interface. type TensorWithPlacement interface { Tensor Placement() DevicePlacement } // Tensor is a minimal core interface // Operations are handled by standalone functions in ops/, matmul/, nn/ type Tensor interface { Shape() Shape DType() DType Device() DeviceType Data() interface{} }