attention_cached_test.go

package nn

import (
	"math"
	"testing"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)

// TestCausalAttentionBlocksMatchesContiguous checks that attention over a
// block view of the KV cache matches attention over contiguous K/V tensors.
func TestCausalAttentionBlocksMatchesContiguous(t *testing.T) {
	numHeads, numKVHeads, headDim := 1, 1, 1
	startPos := 1

	// Two new tokens attending over three total tokens (one past + two new).
	q := cpu.NewTensor(tensor.Shape{2, 1}, []float32{0.5, 1.0})
	kAll := cpu.NewTensor(tensor.Shape{3, 1}, []float32{1, 2, 3})
	vAll := cpu.NewTensor(tensor.Shape{3, 1}, []float32{10, 20, 30})
	outContig := cpu.NewTensor(tensor.Shape{2, 1}, nil)
	if err := CausalAttentionCached(q, kAll, vAll, outContig, numHeads, numKVHeads, headDim, startPos); err != nil {
		t.Fatalf("contiguous attention failed: %v", err)
	}

	// Build a block view equivalent to the contiguous tensors.
	blockK := cpu.NewTensor(tensor.Shape{4, 1}, []float32{1, 2, 3, 0})
	blockV := cpu.NewTensor(tensor.Shape{4, 1}, []float32{10, 20, 30, 0})
	view := kvcache.View{
		K:      blockK,
		V:      blockV,
		Start:  0,
		Length: 3,
		Device: tensor.CPU,
	}
	outBlocks := cpu.NewTensor(tensor.Shape{2, 1}, nil)
	if err := CausalAttentionBlocks(q, []kvcache.View{view}, outBlocks, numHeads, numKVHeads, headDim, startPos); err != nil {
		t.Fatalf("block attention failed: %v", err)
	}

	for i := range outContig.DataFloat32() {
		if diff := math.Abs(float64(outContig.DataFloat32()[i] - outBlocks.DataFloat32()[i])); diff > 1e-5 {
			t.Fatalf("mismatch at %d: contiguous=%v blocks=%v", i, outContig.DataFloat32()[i], outBlocks.DataFloat32()[i])
		}
	}
}

// TestCausalAttentionPackedMatchesBlocks checks that the packed per-head
// block layout produces the same output as the plain block-view path.
func TestCausalAttentionPackedMatchesBlocks(t *testing.T) {
	numHeads, numKVHeads, headDim := 4, 2, 8
	newTokens := 2
	startPos := 4
	blockSize := 8
	kvDim := numKVHeads * headDim

	qData := make([]float32, newTokens*numHeads*headDim)
	for i := range qData {
		qData[i] = float32(i%7) / 7
	}
	q := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, qData)

	// Total KV length includes past + current tokens.
	total := startPos + newTokens
	kData := make([]float32, total*kvDim)
	vData := make([]float32, total*kvDim)
	for i := range kData {
		kData[i] = float32((i%17)-8) / 9
		vData[i] = float32((i%19)-9) / 10
	}

	views := make([]kvcache.View, 0, (total+blockSize-1)/blockSize)
	pviews := make([]kvcache.PackedView, 0, (total+blockSize-1)/blockSize)
	for start := 0; start < total; start += blockSize {
		length := blockSize
		if start+length > total {
			length = total - start
		}

		// Plain block view: copy the [token, kvDim] rows for this block.
		kBlkData := make([]float32, blockSize*kvDim)
		vBlkData := make([]float32, blockSize*kvDim)
		copy(kBlkData, kData[start*kvDim:(start+length)*kvDim])
		copy(vBlkData, vData[start*kvDim:(start+length)*kvDim])
		kBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, kBlkData)
		vBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, vBlkData)
		views = append(views, kvcache.View{K: kBlk, V: vBlk, Start: start, Length: length, Device: tensor.CPU})

		// Packed view: repack the same rows into a per-head
		// [kvHead][token][headDim] layout within the block.
		pk := make([]float32, numKVHeads*blockSize*headDim)
		pv := make([]float32, numKVHeads*blockSize*headDim)
		for ti := 0; ti < length; ti++ {
			baseTok := (start + ti) * kvDim
			for h := 0; h < numKVHeads; h++ {
				srcBase := baseTok + h*headDim
				dstBase := h*(blockSize*headDim) + ti*headDim
				copy(pk[dstBase:dstBase+headDim], kData[srcBase:srcBase+headDim])
				copy(pv[dstBase:dstBase+headDim], vData[srcBase:srcBase+headDim])
			}
		}
		pviews = append(pviews, kvcache.PackedView{K: pk, V: pv, Start: start, Length: length, BlockSize: blockSize, HeadDim: headDim, NumKVHeads: numKVHeads})
	}

	outBlocks := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
	outPacked := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
	if err := CausalAttentionBlocks(q, views, outBlocks, numHeads, numKVHeads, headDim, startPos); err != nil {
		t.Fatalf("blocks attention failed: %v", err)
	}
	if err := CausalAttentionPackedBlocks(q, pviews, outPacked, numHeads, numKVHeads, headDim, startPos); err != nil {
		t.Fatalf("packed attention failed: %v", err)
	}

	for i := range outBlocks.DataFloat32() {
		diff := math.Abs(float64(outBlocks.DataFloat32()[i] - outPacked.DataFloat32()[i]))
		if diff > 1e-5 {
			t.Fatalf("mismatch at %d: blocks=%v packed=%v", i, outBlocks.DataFloat32()[i], outPacked.DataFloat32()[i])
		}
	}
}