| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- package compute
- import (
- "errors"
- "fmt"
- "unsafe"
- "makarna/pkg/backend/cuda"
- "makarna/pkg/tensor"
- )
// DefaultScratchBytes is the scratch buffer size, in bytes, that
// NewScratchSpace falls back to when asked for a non-positive size (256 MiB).
const DefaultScratchBytes = 256 * 1024 * 1024
- // ScratchSpace provides a preallocated GPU buffer for temporary tensors.
- // It avoids repeated cudaMalloc calls during the forward pass by carving
- // out views from a single large allocation.
- type ScratchSpace struct {
- buf *cuda.Tensor
- capacity uintptr
- offset uintptr
- }
// alignUp rounds offset up to the next multiple of align. An align of zero
// disables rounding and returns offset unchanged. align need not be a power
// of two.
func alignUp(offset, align uintptr) uintptr {
	if align == 0 {
		return offset
	}
	if over := offset % align; over != 0 {
		offset += align - over
	}
	return offset
}
- // NewScratchSpace allocates a scratch buffer on the given GPU.
- // If bytes is <= 0, DefaultScratchBytes is used.
- func NewScratchSpace(gpu int, bytes int) (*ScratchSpace, error) {
- if gpu < 0 {
- gpu = 0
- }
- if bytes <= 0 {
- bytes = DefaultScratchBytes
- }
- elemCount := bytes / tensor.Float32.Size()
- if elemCount <= 0 {
- return nil, fmt.Errorf("scratch size too small: %d bytes", bytes)
- }
- buf, err := cuda.NewTensor(tensor.Shape{elemCount}, tensor.Float32, gpu)
- if err != nil {
- return nil, fmt.Errorf("alloc scratch buffer: %w", err)
- }
- return &ScratchSpace{
- buf: buf,
- capacity: uintptr(elemCount * tensor.Float32.Size()),
- offset: 0,
- }, nil
- }
- // GPU returns the GPU id backing this scratch space.
- func (s *ScratchSpace) GPU() int {
- if s == nil || s.buf == nil {
- return -1
- }
- return s.buf.GPU()
- }
- // Reset rewinds the allocation offset to reuse the buffer.
- func (s *ScratchSpace) Reset() {
- if s != nil {
- s.offset = 0
- }
- }
- func (s *ScratchSpace) Free() {
- if s != nil && s.buf != nil {
- s.buf.Free()
- s.buf = nil
- s.capacity = 0
- s.offset = 0
- }
- }
- // GetTensor returns an activation backed by a view into the scratch buffer.
- // Only float32 is supported because the buffer is allocated in float32.
- func (s *ScratchSpace) GetTensor(shape tensor.Shape, dtype tensor.DType) (*Activation, error) {
- if s == nil {
- return nil, errors.New("scratch space is nil")
- }
- if dtype != tensor.Float32 {
- return nil, fmt.Errorf("scratch space only supports float32, got %v", dtype)
- }
- required := uintptr(shape.NumElements() * dtype.Size())
- if required == 0 {
- return nil, errors.New("requested zero-sized tensor from scratch space")
- }
- // Keep allocations reasonably aligned for vectorized loads/stores.
- aligned := alignUp(s.offset, 256)
- if aligned+required > s.capacity {
- return nil, fmt.Errorf("scratch space exhausted: need %d bytes (offset %d / cap %d)", required, aligned, s.capacity)
- }
- view, err := s.buf.ViewAt(shape, aligned)
- if err != nil {
- return nil, err
- }
- s.offset = aligned + required
- return NewActivationFrom(view), nil
- }
- // UnsafePointer exposes the raw pointer to the scratch buffer start.
- // Primarily for debugging or advanced use; avoid in regular code.
- func (s *ScratchSpace) UnsafePointer() unsafe.Pointer {
- if s == nil || s.buf == nil {
- return nil
- }
- if p, ok := s.buf.Data().(unsafe.Pointer); ok {
- return p
- }
- return nil
- }
- // GetInt32Slice allocates a slice of int32s from the scratch buffer.
- // Returns the raw device pointer.
- func (s *ScratchSpace) GetInt32Slice(count int) (unsafe.Pointer, error) {
- if s == nil {
- return nil, errors.New("scratch space is nil")
- }
- size := uintptr(count * 4) // 4 bytes per int32
- aligned := alignUp(s.offset, 16)
- if aligned+size > s.capacity {
- return nil, fmt.Errorf("scratch space exhausted: need %d bytes", size)
- }
- ptr := unsafe.Pointer(uintptr(s.UnsafePointer()) + aligned)
- s.offset = aligned + size
- return ptr, nil
- }
- // GetUintptrSlice allocates a slice of uintptrs from the scratch buffer.
- // Returns the raw device pointer.
- // Ensures the returned pointer is aligned to uintptr size.
- func (s *ScratchSpace) GetUintptrSlice(count int) (unsafe.Pointer, error) {
- if s == nil {
- return nil, errors.New("scratch space is nil")
- }
- if count <= 0 {
- return nil, errors.New("count must be > 0")
- }
- align := uintptr(unsafe.Sizeof(uintptr(0)))
- aligned := alignUp(s.offset, align)
- size := uintptr(count) * align
- if aligned+size > s.capacity {
- return nil, fmt.Errorf("scratch space exhausted: need %d bytes", size)
- }
- ptr := unsafe.Pointer(uintptr(s.UnsafePointer()) + aligned)
- s.offset = aligned + size
- return ptr, nil
- }
|