//go:build cuda package cuda import ( "fmt" "math" "math/rand" "testing" "unsafe" "makarna/pkg/tensor" ) func BenchmarkPagedAttentionF16KV_Decode(b *testing.B) { if !Available() { b.Skip("cuda not available") } const ( gpu = 0 numHeads = 32 numKVHeads = 8 headDim = 128 blockSize = 16 ) const seqLen = 1 kvDim := numKVHeads * headDim scale := float32(1.0 / math.Sqrt(float64(headDim))) kvLens := []int{256, 1024, 2048, 4096} for _, kvLen := range kvLens { b.Run(fmt.Sprintf("kvLen=%d", kvLen), func(b *testing.B) { numBlocks := (kvLen + blockSize - 1) / blockSize if numBlocks <= 0 { b.Fatal("invalid numBlocks") } q, err := NewTensor(tensor.Shape{seqLen, numHeads * headDim}, tensor.Float32, gpu) if err != nil { b.Skipf("cuda alloc failed: %v", err) } defer q.Free() out, err := NewTensor(tensor.Shape{seqLen, numHeads * headDim}, tensor.Float32, gpu) if err != nil { b.Skipf("cuda alloc failed: %v", err) } defer out.Free() // Deterministic Q. r := rand.New(rand.NewSource(1)) hostQ := make([]float32, numHeads*headDim) for i := range hostQ { hostQ[i] = r.Float32()*2 - 1 } if err := q.CopyFrom(hostQ); err != nil { b.Fatalf("CopyFrom Q: %v", err) } // Allocate K/V blocks as F16 and zero-initialize (setup only). zeroHalf := make([]uint16, blockSize*kvDim) kPtrs := make([]uintptr, numBlocks) vPtrs := make([]uintptr, numBlocks) kBlocks := make([]*Tensor, numBlocks) vBlocks := make([]*Tensor, numBlocks) for i := 0; i < numBlocks; i++ { kb, err := NewTensor(tensor.Shape{blockSize, kvDim}, tensor.Float16, gpu) if err != nil { b.Skipf("cuda alloc failed: %v", err) } vb, err := NewTensor(tensor.Shape{blockSize, kvDim}, tensor.Float16, gpu) if err != nil { kb.Free() b.Skipf("cuda alloc failed: %v", err) } kBlocks[i] = kb vBlocks[i] = vb kPtrs[i] = uintptr(kb.Data().(unsafe.Pointer)) vPtrs[i] = uintptr(vb.Data().(unsafe.Pointer)) if err := MemcpyH2D(kb.Data().(unsafe.Pointer), unsafe.Pointer(&zeroHalf[0]), uintptr(len(zeroHalf)*2), gpu); err != nil { b.Fatalf("zero K: %v", err) } if err := MemcpyH2D(vb.Data().(unsafe.Pointer), unsafe.Pointer(&zeroHalf[0]), uintptr(len(zeroHalf)*2), gpu); err != nil { b.Fatalf("zero V: %v", err) } } defer func() { for i := range kBlocks { if kBlocks[i] != nil { kBlocks[i].Free() } if vBlocks[i] != nil { vBlocks[i].Free() } } }() kDev, err := AllocAndCopyPtrTable(kPtrs, gpu) if err != nil { b.Fatalf("AllocAndCopyPtrTable K: %v", err) } defer FreeDevicePtr(kDev) vDev, err := AllocAndCopyPtrTable(vPtrs, gpu) if err != nil { FreeDevicePtr(kDev) b.Fatalf("AllocAndCopyPtrTable V: %v", err) } defer FreeDevicePtr(vDev) startPos := kvLen - 1 b.ResetTimer() for i := 0; i < b.N; i++ { if err := PagedAttentionF32F16KV( q.Data().(unsafe.Pointer), kDev, vDev, out.Data().(unsafe.Pointer), seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize, scale, startPos, gpu, ); err != nil { b.Fatalf("PagedAttentionF32F16KV: %v", err) } } // Ensure all kernels are complete before timing finishes. _ = Synchronize(gpu) }) } }