attention_batch_test.go

package nn

import (
	"math"
	"math/rand"
	"testing"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)

// TestCausalAttentionPackedBlocksBatch_MatchesPerToken verifies that the
// batched packed-block attention kernel produces the same output as running
// the single-token kernel once per query row.
func TestCausalAttentionPackedBlocksBatch_MatchesPerToken(t *testing.T) {
	numTokens := 7
	numHeads := 8
	numKVHeads := 2
	headDim := 16
	blockSize := 4
	qStride := numHeads * headDim
	queryPos := []int{0, 3, 5, 1, 7, 2, 9}
	if len(queryPos) != numTokens {
		t.Fatalf("test bug: queryPos len %d != numTokens %d", len(queryPos), numTokens)
	}

	// For each token, build a causal KV history of queryPos[tok]+1 positions,
	// packed into fixed-size blocks filled with random keys and values.
	rng := rand.New(rand.NewSource(1234))
	viewsByToken := make([][]kvcache.PackedView, numTokens)
	for tok := 0; tok < numTokens; tok++ {
		kvLen := queryPos[tok] + 1
		if kvLen <= 0 {
			t.Fatalf("kvLen must be > 0, got %d", kvLen)
		}
		numBlocks := (kvLen + blockSize - 1) / blockSize
		views := make([]kvcache.PackedView, 0, numBlocks)
		for b := 0; b < numBlocks; b++ {
			start := b * blockSize
			length := blockSize
			if start+length > kvLen {
				length = kvLen - start
			}
			if length <= 0 {
				break
			}
			blkStride := blockSize * headDim
			k := make([]float32, numKVHeads*blkStride)
			v := make([]float32, numKVHeads*blkStride)
			for i := range k {
				k[i] = rng.Float32()*2 - 1
				v[i] = rng.Float32()*2 - 1
			}
			views = append(views, kvcache.PackedView{
				K:          k,
				V:          v,
				Start:      start,
				Length:     length,
				BlockSize:  blockSize,
				HeadDim:    headDim,
				NumKVHeads: numKVHeads,
			})
		}
		viewsByToken[tok] = views
	}

	// Random query rows, one per token.
	q := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
	qData := q.DataFloat32()
	for i := range qData {
		qData[i] = rng.Float32()*2 - 1
	}

	// Batched path: all tokens processed in a single call.
	outBatch := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
	if err := CausalAttentionPackedBlocksBatch(q, viewsByToken, outBatch, numHeads, numKVHeads, headDim, queryPos); err != nil {
		t.Fatalf("CausalAttentionPackedBlocksBatch: %v", err)
	}

	// Reference path: the per-token kernel, one query row at a time.
	outRef := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
	outRefData := outRef.DataFloat32()
	for tok := 0; tok < numTokens; tok++ {
		qRow := cpu.NewTensor(tensor.Shape{1, qStride}, qData[tok*qStride:(tok+1)*qStride])
		outRow := cpu.NewTensor(tensor.Shape{1, qStride}, outRefData[tok*qStride:(tok+1)*qStride])
		if err := CausalAttentionPackedBlocks(qRow, viewsByToken[tok], outRow, numHeads, numKVHeads, headDim, queryPos[tok]); err != nil {
			t.Fatalf("CausalAttentionPackedBlocks tok=%d: %v", tok, err)
		}
	}

	// The batched output must match the per-token reference within tolerance.
	got := outBatch.DataFloat32()
	want := outRef.DataFloat32()
	if len(got) != len(want) {
		t.Fatalf("length mismatch: got %d want %d", len(got), len(want))
	}
	const tol = 1e-4
	for i := range got {
		if diff := math.Abs(float64(got[i] - want[i])); diff > tol {
			t.Fatalf("mismatch at %d: got=%g want=%g diff=%g", i, got[i], want[i], diff)
		}
	}
}