| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- package kvcache
- import (
- "testing"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/tensor"
- )
- func TestPagedKVCacheAppendViewsIncludeUncommitted(t *testing.T) {
- pool, err := NewBlockPool(BlockPoolConfig{
- NumLayers: 1,
- NumKVHeads: 1,
- HeadDim: 2,
- BlockSize: 4,
- NumBlocks: 4,
- Device: tensor.CPU,
- GPU: 0,
- })
- if err != nil {
- t.Fatalf("NewBlockPool: %v", err)
- }
- cache := NewPagedKVCache(pool, PagedCacheConfig{
- NumLayers: 1,
- NumKVHeads: 1,
- HeadDim: 2,
- BlockSize: 4,
- MaxSeqLen: 16,
- Device: tensor.CPU,
- GPU: 0,
- }, "req")
- defer cache.Free()
- k := cpu.NewTensor(tensor.Shape{2, 2}, []float32{1, 2, 3, 4})
- v := cpu.NewTensor(tensor.Shape{2, 2}, []float32{5, 6, 7, 8})
- views, start, err := cache.Append(0, k, v)
- if err != nil {
- t.Fatalf("Append: %v", err)
- }
- if start != 0 {
- t.Fatalf("expected start 0, got %d", start)
- }
- if cache.SeqLen() != 0 {
- t.Fatalf("expected SeqLen 0 before commit, got %d", cache.SeqLen())
- }
- if len(views) == 0 {
- t.Fatalf("expected non-empty views")
- }
- last := views[len(views)-1]
- if last.Start+last.Length < 2 {
- t.Fatalf("expected views to include appended tokens, got last view start=%d len=%d", last.Start, last.Length)
- }
- committedViews := cache.Views(0)
- if len(committedViews) != 0 {
- t.Fatalf("expected Views() to be empty before commit, got %d", len(committedViews))
- }
- cache.Commit(2)
- if cache.SeqLen() != 2 {
- t.Fatalf("expected SeqLen 2 after commit, got %d", cache.SeqLen())
- }
- committedViews = cache.Views(0)
- if len(committedViews) == 0 {
- t.Fatalf("expected Views() to be non-empty after commit")
- }
- last = committedViews[len(committedViews)-1]
- if last.Start+last.Length < 2 {
- t.Fatalf("expected committed views to include tokens, got last view start=%d len=%d", last.Start, last.Length)
- }
- }
- func TestPagedKVCacheViewsPackedLayout(t *testing.T) {
- pool, err := NewBlockPool(BlockPoolConfig{
- NumLayers: 1,
- NumKVHeads: 2,
- HeadDim: 2,
- BlockSize: 4,
- NumBlocks: 4,
- Device: tensor.CPU,
- GPU: 0,
- })
- if err != nil {
- t.Fatalf("NewBlockPool: %v", err)
- }
- cache := NewPagedKVCache(pool, PagedCacheConfig{
- NumLayers: 1,
- NumKVHeads: 2,
- HeadDim: 2,
- BlockSize: 4,
- MaxSeqLen: 16,
- Device: tensor.CPU,
- GPU: 0,
- }, "req")
- defer cache.Free()
- // token-major layout: [t][kvHead][d]
- // kvDim = 2 * 2 = 4
- k := cpu.NewTensor(tensor.Shape{2, 4}, []float32{
- 1, 2, 3, 4,
- 5, 6, 7, 8,
- })
- v := cpu.NewTensor(tensor.Shape{2, 4}, []float32{
- 10, 20, 30, 40,
- 50, 60, 70, 80,
- })
- if _, _, err := cache.Append(0, k, v); err != nil {
- t.Fatalf("Append: %v", err)
- }
- // Should be available pre-commit (uses numWritten).
- if cache.SeqLen() != 0 {
- t.Fatalf("expected SeqLen 0 before commit, got %d", cache.SeqLen())
- }
- pviews := cache.ViewsPacked(0)
- if len(pviews) == 0 {
- t.Fatalf("expected non-empty packed views")
- }
- if len(pviews) != 1 {
- t.Fatalf("expected 1 packed view, got %d", len(pviews))
- }
- pv := pviews[0]
- if pv.Start != 0 || pv.Length != 2 {
- t.Fatalf("unexpected packed view range: start=%d len=%d", pv.Start, pv.Length)
- }
- if pv.BlockSize != 4 || pv.HeadDim != 2 || pv.NumKVHeads != 2 {
- t.Fatalf("unexpected packed view meta: block=%d headDim=%d kvHeads=%d", pv.BlockSize, pv.HeadDim, pv.NumKVHeads)
- }
- blockSize := pv.BlockSize
- headDim := pv.HeadDim
- numKVHeads := pv.NumKVHeads
- kvDim := numKVHeads * headDim
- for ti := 0; ti < pv.Length; ti++ {
- for h := 0; h < numKVHeads; h++ {
- for d := 0; d < headDim; d++ {
- src := k.DataFloat32()[ti*kvDim+h*headDim+d]
- dst := pv.K[h*(blockSize*headDim)+ti*headDim+d]
- if src != dst {
- t.Fatalf("K packed mismatch t=%d h=%d d=%d: src=%v dst=%v", ti, h, d, src, dst)
- }
- srcV := v.DataFloat32()[ti*kvDim+h*headDim+d]
- dstV := pv.V[h*(blockSize*headDim)+ti*headDim+d]
- if srcV != dstV {
- t.Fatalf("V packed mismatch t=%d h=%d d=%d: src=%v dst=%v", ti, h, d, srcV, dstV)
- }
- }
- }
- }
- }
|