package compute

import (
	"errors"
	"fmt"
	"unsafe"

	"makarna/pkg/backend/cuda"
	"makarna/pkg/tensor"
)

// DefaultScratchBytes defines the fallback scratch buffer size (256 MiB).
const DefaultScratchBytes = 256 << 20

// ScratchSpace provides a preallocated GPU buffer for temporary tensors.
// It avoids repeated cudaMalloc calls during the forward pass by carving
// out views from a single large allocation.
type ScratchSpace struct {
	buf      *cuda.Tensor
	capacity uintptr
	offset   uintptr
}

// alignUp rounds offset up to the next multiple of align.
func alignUp(offset, align uintptr) uintptr {
	if align == 0 {
		return offset
	}
	rem := offset % align
	if rem == 0 {
		return offset
	}
	return offset + (align - rem)
}

// NewScratchSpace allocates a scratch buffer on the given GPU.
// If bytes is <= 0, DefaultScratchBytes is used.
func NewScratchSpace(gpu int, bytes int) (*ScratchSpace, error) {
	if gpu < 0 {
		gpu = 0
	}
	if bytes <= 0 {
		bytes = DefaultScratchBytes
	}
	elemCount := bytes / tensor.Float32.Size()
	if elemCount <= 0 {
		return nil, fmt.Errorf("scratch size too small: %d bytes", bytes)
	}
	buf, err := cuda.NewTensor(tensor.Shape{elemCount}, tensor.Float32, gpu)
	if err != nil {
		return nil, fmt.Errorf("alloc scratch buffer: %w", err)
	}
	return &ScratchSpace{
		buf:      buf,
		capacity: uintptr(elemCount * tensor.Float32.Size()),
		offset:   0,
	}, nil
}

// GPU returns the GPU id backing this scratch space.
func (s *ScratchSpace) GPU() int {
	if s == nil || s.buf == nil {
		return -1
	}
	return s.buf.GPU()
}

// Reset rewinds the allocation offset to reuse the buffer. All views handed
// out before the reset become invalid and must no longer be used.
func (s *ScratchSpace) Reset() {
	if s != nil {
		s.offset = 0
	}
}

// Free releases the underlying GPU buffer.
func (s *ScratchSpace) Free() {
	if s != nil && s.buf != nil {
		s.buf.Free()
		s.buf = nil
		s.capacity = 0
		s.offset = 0
	}
}

// GetTensor returns an activation backed by a view into the scratch buffer.
// Only float32 is supported because the buffer is allocated in float32.
func (s *ScratchSpace) GetTensor(shape tensor.Shape, dtype tensor.DType) (*Activation, error) {
	if s == nil {
		return nil, errors.New("scratch space is nil")
	}
	if dtype != tensor.Float32 {
		return nil, fmt.Errorf("scratch space only supports float32, got %v", dtype)
	}
	required := uintptr(shape.NumElements() * dtype.Size())
	if required == 0 {
		return nil, errors.New("requested zero-sized tensor from scratch space")
	}
	// Keep allocations reasonably aligned for vectorized loads/stores.
	aligned := alignUp(s.offset, 256)
	if aligned+required > s.capacity {
		return nil, fmt.Errorf("scratch space exhausted: need %d bytes (offset %d / cap %d)",
			required, aligned, s.capacity)
	}
	view, err := s.buf.ViewAt(shape, aligned)
	if err != nil {
		return nil, err
	}
	s.offset = aligned + required
	return NewActivationFrom(view), nil
}

// UnsafePointer exposes the raw pointer to the scratch buffer start.
// Primarily for debugging or advanced use; avoid in regular code.
func (s *ScratchSpace) UnsafePointer() unsafe.Pointer {
	if s == nil || s.buf == nil {
		return nil
	}
	if p, ok := s.buf.Data().(unsafe.Pointer); ok {
		return p
	}
	return nil
}

// GetInt32Slice allocates a slice of int32s from the scratch buffer and
// returns the raw device pointer, aligned to 16 bytes.
func (s *ScratchSpace) GetInt32Slice(count int) (unsafe.Pointer, error) {
	if s == nil {
		return nil, errors.New("scratch space is nil")
	}
	if count <= 0 {
		return nil, errors.New("count must be > 0")
	}
	base := s.UnsafePointer()
	if base == nil {
		return nil, errors.New("scratch buffer is not allocated")
	}
	size := uintptr(count) * 4 // 4 bytes per int32
	aligned := alignUp(s.offset, 16)
	if aligned+size > s.capacity {
		return nil, fmt.Errorf("scratch space exhausted: need %d bytes", size)
	}
	ptr := unsafe.Pointer(uintptr(base) + aligned)
	s.offset = aligned + size
	return ptr, nil
}

// GetUintptrSlice allocates a slice of uintptrs from the scratch buffer and
// returns the raw device pointer, aligned to the size of a uintptr.
func (s *ScratchSpace) GetUintptrSlice(count int) (unsafe.Pointer, error) {
	if s == nil {
		return nil, errors.New("scratch space is nil")
	}
	if count <= 0 {
		return nil, errors.New("count must be > 0")
	}
	base := s.UnsafePointer()
	if base == nil {
		return nil, errors.New("scratch buffer is not allocated")
	}
	align := uintptr(unsafe.Sizeof(uintptr(0)))
	aligned := alignUp(s.offset, align)
	size := uintptr(count) * align
	if aligned+size > s.capacity {
		return nil, fmt.Errorf("scratch space exhausted: need %d bytes", size)
	}
	ptr := unsafe.Pointer(uintptr(base) + aligned)
	s.offset = aligned + size
	return ptr, nil
}
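
// The sketch below illustrates the intended bump-allocator lifecycle:
// allocate the scratch space once, then Reset at the start of each forward
// pass and carve activations out of the buffer. It is an illustrative
// example only, not referenced elsewhere in the package; the steps, batch,
// and hidden parameters are hypothetical stand-ins for model configuration.
func exampleForwardPassSketch(steps, batch, hidden int) error {
	scratch, err := NewScratchSpace(0, 0) // GPU 0, DefaultScratchBytes
	if err != nil {
		return err
	}
	defer scratch.Free()

	for i := 0; i < steps; i++ {
		// Rewind the allocator; views from the previous iteration must
		// no longer be in use at this point.
		scratch.Reset()

		act, err := scratch.GetTensor(tensor.Shape{batch, hidden}, tensor.Float32)
		if err != nil {
			return err
		}
		_ = act // would be written to / read by layer kernels here
	}
	return nil
}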