memory.go

package compute

import (
	"errors"
	"fmt"
	"unsafe"

	"makarna/pkg/backend/cuda"
	"makarna/pkg/tensor"
)

// DefaultScratchBytes defines the fallback scratch buffer size (256 MiB).
const DefaultScratchBytes = 256 << 20

// ScratchSpace provides a preallocated GPU buffer for temporary tensors.
// It avoids repeated cudaMalloc calls during the forward pass by carving
// out views from a single large allocation.
type ScratchSpace struct {
	buf      *cuda.Tensor
	capacity uintptr
	offset   uintptr
}

// alignUp rounds offset up to the next multiple of align.
// An align of zero leaves the offset unchanged.
func alignUp(offset, align uintptr) uintptr {
	if align == 0 {
		return offset
	}
	rem := offset % align
	if rem == 0 {
		return offset
	}
	return offset + (align - rem)
}

// NewScratchSpace allocates a scratch buffer on the given GPU.
// If bytes is <= 0, DefaultScratchBytes is used.
func NewScratchSpace(gpu int, bytes int) (*ScratchSpace, error) {
	if gpu < 0 {
		gpu = 0
	}
	if bytes <= 0 {
		bytes = DefaultScratchBytes
	}
	elemCount := bytes / tensor.Float32.Size()
	if elemCount <= 0 {
		return nil, fmt.Errorf("scratch size too small: %d bytes", bytes)
	}
	buf, err := cuda.NewTensor(tensor.Shape{elemCount}, tensor.Float32, gpu)
	if err != nil {
		return nil, fmt.Errorf("alloc scratch buffer: %w", err)
	}
	return &ScratchSpace{
		buf:      buf,
		capacity: uintptr(elemCount * tensor.Float32.Size()),
		offset:   0,
	}, nil
}
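
// A minimal construction sketch (assumes GPU 0 is available; the size and
// the variable name scratch are illustrative):
//
//	scratch, err := NewScratchSpace(0, 64<<20) // 64 MiB on GPU 0
//	if err != nil {
//		return err
//	}
//	defer scratch.Free()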

// GPU returns the GPU id backing this scratch space.
func (s *ScratchSpace) GPU() int {
	if s == nil || s.buf == nil {
		return -1
	}
	return s.buf.GPU()
}

// Reset rewinds the allocation offset to reuse the buffer.
func (s *ScratchSpace) Reset() {
	if s != nil {
		s.offset = 0
	}
}
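
// The intended reuse pattern (a sketch; scratch and batches are illustrative
// names): reset once per forward pass so each pass carves its temporaries
// from the same allocation instead of triggering fresh cudaMalloc calls:
//
//	for _, batch := range batches {
//		scratch.Reset()
//		// ... forward pass, calling scratch.GetTensor for temporaries ...
//	}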

// Free releases the underlying GPU allocation. The scratch space must not
// be used for new allocations afterwards.
func (s *ScratchSpace) Free() {
	if s != nil && s.buf != nil {
		s.buf.Free()
		s.buf = nil
		s.capacity = 0
		s.offset = 0
	}
}

// GetTensor returns an activation backed by a view into the scratch buffer.
// Only float32 is supported because the buffer is allocated in float32.
func (s *ScratchSpace) GetTensor(shape tensor.Shape, dtype tensor.DType) (*Activation, error) {
	if s == nil {
		return nil, errors.New("scratch space is nil")
	}
	if dtype != tensor.Float32 {
		return nil, fmt.Errorf("scratch space only supports float32, got %v", dtype)
	}
	required := uintptr(shape.NumElements() * dtype.Size())
	if required == 0 {
		return nil, errors.New("requested zero-sized tensor from scratch space")
	}
	// Keep allocations reasonably aligned for vectorized loads/stores.
	aligned := alignUp(s.offset, 256)
	if aligned+required > s.capacity {
		return nil, fmt.Errorf("scratch space exhausted: need %d bytes (offset %d / cap %d)", required, aligned, s.capacity)
	}
	view, err := s.buf.ViewAt(shape, aligned)
	if err != nil {
		return nil, err
	}
	s.offset = aligned + required
	return NewActivationFrom(view), nil
}
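
// For instance (a sketch; the shape is illustrative), two back-to-back
// requests return non-overlapping views, each starting on a 256-byte
// boundary. A 4x1024 float32 tensor needs 4*1024*4 = 16384 bytes, which is
// itself 256-aligned, so the second view begins exactly where the first ends:
//
//	a, _ := scratch.GetTensor(tensor.Shape{4, 1024}, tensor.Float32) // bytes [0, 16384)
//	b, _ := scratch.GetTensor(tensor.Shape{4, 1024}, tensor.Float32) // bytes [16384, 32768)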

// UnsafePointer exposes the raw pointer to the scratch buffer start.
// Primarily for debugging or advanced use; avoid in regular code.
func (s *ScratchSpace) UnsafePointer() unsafe.Pointer {
	if s == nil || s.buf == nil {
		return nil
	}
	if p, ok := s.buf.Data().(unsafe.Pointer); ok {
		return p
	}
	return nil
}

// GetInt32Slice reserves room for count int32 values in the scratch buffer
// and returns the raw device pointer to the reservation.
func (s *ScratchSpace) GetInt32Slice(count int) (unsafe.Pointer, error) {
	if s == nil {
		return nil, errors.New("scratch space is nil")
	}
	if count <= 0 {
		return nil, errors.New("count must be > 0")
	}
	base := s.UnsafePointer()
	if base == nil {
		return nil, errors.New("scratch space has no backing buffer")
	}
	size := uintptr(count) * 4 // 4 bytes per int32
	aligned := alignUp(s.offset, 16)
	if aligned+size > s.capacity {
		return nil, fmt.Errorf("scratch space exhausted: need %d bytes", size)
	}
	ptr := unsafe.Pointer(uintptr(base) + aligned)
	s.offset = aligned + size
	return ptr, nil
}
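
// Sketch of intended use (the kernel side is assumed, not defined in this
// file): reserve room for 512 int32 indices and hand the device pointer to a
// CUDA kernel expecting an int32 array of that length:
//
//	idx, err := scratch.GetInt32Slice(512)
//	if err != nil {
//		return err
//	}
//	// launch a kernel that reads 512 int32 values starting at idx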

// GetUintptrSlice reserves room for count pointer-sized values in the
// scratch buffer and returns the raw device pointer to the reservation,
// aligned to the size of a uintptr.
func (s *ScratchSpace) GetUintptrSlice(count int) (unsafe.Pointer, error) {
	if s == nil {
		return nil, errors.New("scratch space is nil")
	}
	if count <= 0 {
		return nil, errors.New("count must be > 0")
	}
	base := s.UnsafePointer()
	if base == nil {
		return nil, errors.New("scratch space has no backing buffer")
	}
	align := unsafe.Sizeof(uintptr(0)) // Sizeof already returns a uintptr
	aligned := alignUp(s.offset, align)
	size := uintptr(count) * align
	if aligned+size > s.capacity {
		return nil, fmt.Errorf("scratch space exhausted: need %d bytes", size)
	}
	ptr := unsafe.Pointer(uintptr(base) + aligned)
	s.offset = aligned + size
	return ptr, nil
}
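
// Sketch of intended use (the batched-kernel use case is an assumption, not
// something this file defines; batchCount is illustrative): reserve a
// device-side pointer table, e.g. for a kernel that takes an array of
// per-matrix base pointers:
//
//	table, err := scratch.GetUintptrSlice(batchCount)
//	if err != nil {
//		return err
//	}
//	// copy batchCount device pointers into table, then launch the kernel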