
attention_bench_test.go

//go:build cuda

package cuda

import (
	"fmt"
	"math"
	"math/rand"
	"testing"
	"unsafe"

	"makarna/pkg/tensor"
)
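
// BenchmarkPagedAttentionF16KV_Decode measures single-token (decode) paged
// attention over an F16 KV cache at several cache lengths, with 32 query
// heads sharing 8 KV heads (grouped-query attention).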
func BenchmarkPagedAttentionF16KV_Decode(b *testing.B) {
	if !Available() {
		b.Skip("cuda not available")
	}
	const (
		gpu        = 0
		numHeads   = 32
		numKVHeads = 8
		headDim    = 128
		blockSize  = 16
	)
	const seqLen = 1
	kvDim := numKVHeads * headDim
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	kvLens := []int{256, 1024, 2048, 4096}
	for _, kvLen := range kvLens {
		b.Run(fmt.Sprintf("kvLen=%d", kvLen), func(b *testing.B) {
			numBlocks := (kvLen + blockSize - 1) / blockSize
			if numBlocks <= 0 {
				b.Fatal("invalid numBlocks")
			}
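			// Q and the output stay in F32; only the cached K/V blocks are F16.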
			q, err := NewTensor(tensor.Shape{seqLen, numHeads * headDim}, tensor.Float32, gpu)
			if err != nil {
				b.Skipf("cuda alloc failed: %v", err)
			}
			defer q.Free()
			out, err := NewTensor(tensor.Shape{seqLen, numHeads * headDim}, tensor.Float32, gpu)
			if err != nil {
				b.Skipf("cuda alloc failed: %v", err)
			}
			defer out.Free()
			// Deterministic Q.
			r := rand.New(rand.NewSource(1))
			hostQ := make([]float32, numHeads*headDim)
			for i := range hostQ {
				hostQ[i] = r.Float32()*2 - 1
			}
			if err := q.CopyFrom(hostQ); err != nil {
				b.Fatalf("CopyFrom Q: %v", err)
			}
			// Allocate K/V blocks as F16 and zero-initialize (setup only).
			zeroHalf := make([]uint16, blockSize*kvDim)
			kPtrs := make([]uintptr, numBlocks)
			vPtrs := make([]uintptr, numBlocks)
			kBlocks := make([]*Tensor, numBlocks)
			vBlocks := make([]*Tensor, numBlocks)
			// Register cleanup before the allocation loop so blocks allocated in
			// earlier iterations are freed even if a later allocation fails and skips.
			defer func() {
				for i := range kBlocks {
					if kBlocks[i] != nil {
						kBlocks[i].Free()
					}
					if vBlocks[i] != nil {
						vBlocks[i].Free()
					}
				}
			}()
			for i := 0; i < numBlocks; i++ {
				kb, err := NewTensor(tensor.Shape{blockSize, kvDim}, tensor.Float16, gpu)
				if err != nil {
					b.Skipf("cuda alloc failed: %v", err)
				}
				vb, err := NewTensor(tensor.Shape{blockSize, kvDim}, tensor.Float16, gpu)
				if err != nil {
					kb.Free()
					b.Skipf("cuda alloc failed: %v", err)
				}
				kBlocks[i] = kb
				vBlocks[i] = vb
				kPtrs[i] = uintptr(kb.Data().(unsafe.Pointer))
				vPtrs[i] = uintptr(vb.Data().(unsafe.Pointer))
				if err := MemcpyH2D(kb.Data().(unsafe.Pointer), unsafe.Pointer(&zeroHalf[0]), uintptr(len(zeroHalf)*2), gpu); err != nil {
					b.Fatalf("zero K: %v", err)
				}
				if err := MemcpyH2D(vb.Data().(unsafe.Pointer), unsafe.Pointer(&zeroHalf[0]), uintptr(len(zeroHalf)*2), gpu); err != nil {
					b.Fatalf("zero V: %v", err)
				}
			}
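			// The kernel indexes K/V blocks through device-resident pointer tables,
			// so copy the per-block device pointers up front, outside the timed loop.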
			kDev, err := AllocAndCopyPtrTable(kPtrs, gpu)
			if err != nil {
				b.Fatalf("AllocAndCopyPtrTable K: %v", err)
			}
			defer FreeDevicePtr(kDev)
			vDev, err := AllocAndCopyPtrTable(vPtrs, gpu)
			if err != nil {
				// kDev is released by the deferred FreeDevicePtr above.
				b.Fatalf("AllocAndCopyPtrTable V: %v", err)
			}
			defer FreeDevicePtr(vDev)
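			// Decode position: the single new query sits at the end of the cache,
			// attending over all kvLen cached entries.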
			startPos := kvLen - 1
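			// Setup above is excluded from the measurement.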
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				if err := PagedAttentionF32F16KV(
					q.Data().(unsafe.Pointer),
					kDev,
					vDev,
					out.Data().(unsafe.Pointer),
					seqLen,
					kvLen,
					numHeads,
					numKVHeads,
					headDim,
					blockSize,
					scale,
					startPos,
					gpu,
				); err != nil {
					b.Fatalf("PagedAttentionF32F16KV: %v", err)
				}
			}
			// Ensure all kernels are complete before timing finishes.
			_ = Synchronize(gpu)
		})
	}
}