paged_cache_test.go

package kvcache

import (
	"testing"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/tensor"
)
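
// TestPagedKVCacheAppendViewsIncludeUncommitted verifies that the views
// returned by Append already cover the just-appended (uncommitted) tokens,
// while SeqLen() and Views() keep reflecting only committed state until
// Commit is called.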
func TestPagedKVCacheAppendViewsIncludeUncommitted(t *testing.T) {
	pool, err := NewBlockPool(BlockPoolConfig{
		NumLayers:  1,
		NumKVHeads: 1,
		HeadDim:    2,
		BlockSize:  4,
		NumBlocks:  4,
		Device:     tensor.CPU,
		GPU:        0,
	})
	if err != nil {
		t.Fatalf("NewBlockPool: %v", err)
	}
	cache := NewPagedKVCache(pool, PagedCacheConfig{
		NumLayers:  1,
		NumKVHeads: 1,
		HeadDim:    2,
		BlockSize:  4,
		MaxSeqLen:  16,
		Device:     tensor.CPU,
		GPU:        0,
	}, "req")
	defer cache.Free()
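
	// Two tokens of K/V: kvDim = NumKVHeads*HeadDim = 1*2 = 2, so each tensor
	// has shape {2, 2} (one row per token).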
	k := cpu.NewTensor(tensor.Shape{2, 2}, []float32{1, 2, 3, 4})
	v := cpu.NewTensor(tensor.Shape{2, 2}, []float32{5, 6, 7, 8})

	views, start, err := cache.Append(0, k, v)
	if err != nil {
		t.Fatalf("Append: %v", err)
	}
	if start != 0 {
		t.Fatalf("expected start 0, got %d", start)
	}
	if cache.SeqLen() != 0 {
		t.Fatalf("expected SeqLen 0 before commit, got %d", cache.SeqLen())
	}
	if len(views) == 0 {
		t.Fatalf("expected non-empty views")
	}
	last := views[len(views)-1]
	if last.Start+last.Length < 2 {
		t.Fatalf("expected views to include appended tokens, got last view start=%d len=%d", last.Start, last.Length)
	}

	committedViews := cache.Views(0)
	if len(committedViews) != 0 {
		t.Fatalf("expected Views() to be empty before commit, got %d", len(committedViews))
	}
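
	// Committing the two appended tokens should make them visible through
	// SeqLen() and Views().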
	cache.Commit(2)
	if cache.SeqLen() != 2 {
		t.Fatalf("expected SeqLen 2 after commit, got %d", cache.SeqLen())
	}
	committedViews = cache.Views(0)
	if len(committedViews) == 0 {
		t.Fatalf("expected Views() to be non-empty after commit")
	}
	last = committedViews[len(committedViews)-1]
	if last.Start+last.Length < 2 {
		t.Fatalf("expected committed views to include tokens, got last view start=%d len=%d", last.Start, last.Length)
	}
}
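
// TestPagedKVCacheViewsPackedLayout appends token-major K/V data and checks
// that ViewsPacked exposes the same values per block, indexed as
// [kvHead][token][headDim].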
func TestPagedKVCacheViewsPackedLayout(t *testing.T) {
	pool, err := NewBlockPool(BlockPoolConfig{
		NumLayers:  1,
		NumKVHeads: 2,
		HeadDim:    2,
		BlockSize:  4,
		NumBlocks:  4,
		Device:     tensor.CPU,
		GPU:        0,
	})
	if err != nil {
		t.Fatalf("NewBlockPool: %v", err)
	}
	cache := NewPagedKVCache(pool, PagedCacheConfig{
		NumLayers:  1,
		NumKVHeads: 2,
		HeadDim:    2,
		BlockSize:  4,
		MaxSeqLen:  16,
		Device:     tensor.CPU,
		GPU:        0,
	}, "req")
	defer cache.Free()

	// Token-major layout: [t][kvHead][d]; kvDim = 2 * 2 = 4.
	k := cpu.NewTensor(tensor.Shape{2, 4}, []float32{
		1, 2, 3, 4,
		5, 6, 7, 8,
	})
	v := cpu.NewTensor(tensor.Shape{2, 4}, []float32{
		10, 20, 30, 40,
		50, 60, 70, 80,
	})
	if _, _, err := cache.Append(0, k, v); err != nil {
		t.Fatalf("Append: %v", err)
	}
	// Packed views should be available pre-commit (they use numWritten).
	if cache.SeqLen() != 0 {
		t.Fatalf("expected SeqLen 0 before commit, got %d", cache.SeqLen())
	}
	pviews := cache.ViewsPacked(0)
	if len(pviews) == 0 {
		t.Fatalf("expected non-empty packed views")
	}
	if len(pviews) != 1 {
		t.Fatalf("expected 1 packed view, got %d", len(pviews))
	}
	pv := pviews[0]
	if pv.Start != 0 || pv.Length != 2 {
		t.Fatalf("unexpected packed view range: start=%d len=%d", pv.Start, pv.Length)
	}
	if pv.BlockSize != 4 || pv.HeadDim != 2 || pv.NumKVHeads != 2 {
		t.Fatalf("unexpected packed view meta: block=%d headDim=%d kvHeads=%d", pv.BlockSize, pv.HeadDim, pv.NumKVHeads)
	}

	blockSize := pv.BlockSize
	headDim := pv.HeadDim
	numKVHeads := pv.NumKVHeads
	kvDim := numKVHeads * headDim
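
	// The source tensors are token-major ([t][kvHead][d], token stride kvDim);
	// the test expects each packed block buffer to be head-major
	// ([kvHead][t][d], head stride blockSize*headDim). Compare element-wise.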
	for ti := 0; ti < pv.Length; ti++ {
		for h := 0; h < numKVHeads; h++ {
			for d := 0; d < headDim; d++ {
				src := k.DataFloat32()[ti*kvDim+h*headDim+d]
				dst := pv.K[h*(blockSize*headDim)+ti*headDim+d]
				if src != dst {
					t.Fatalf("K packed mismatch t=%d h=%d d=%d: src=%v dst=%v", ti, h, d, src, dst)
				}
				srcV := v.DataFloat32()[ti*kvDim+h*headDim+d]
				dstV := pv.V[h*(blockSize*headDim)+ti*headDim+d]
				if srcV != dstV {
					t.Fatalf("V packed mismatch t=%d h=%d d=%d: src=%v dst=%v", ti, h, d, srcV, dstV)
				}
			}
		}
	}
}