package nn

import (
	"math"
	"testing"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)

// TestCausalAttentionBlocksMatchesContiguous checks that attention computed
// over a block-structured KV cache (CausalAttentionBlocks) produces the same
// output as attention over a contiguous K/V tensor (CausalAttentionCached),
// using a minimal 1-head / 1-dim configuration with one past token and two
// new tokens.
func TestCausalAttentionBlocksMatchesContiguous(t *testing.T) {
	numHeads, numKVHeads, headDim := 1, 1, 1
	startPos := 1
	// Two new tokens attending over three total tokens (one past + two new).
	q := cpu.NewTensor(tensor.Shape{2, 1}, []float32{0.5, 1.0})
	kAll := cpu.NewTensor(tensor.Shape{3, 1}, []float32{1, 2, 3})
	vAll := cpu.NewTensor(tensor.Shape{3, 1}, []float32{10, 20, 30})
	outContig := cpu.NewTensor(tensor.Shape{2, 1}, nil)
	if err := CausalAttentionCached(q, kAll, vAll, outContig, numHeads, numKVHeads, headDim, startPos); err != nil {
		t.Fatalf("contiguous attention failed: %v", err)
	}
	// Build block view equivalent to the contiguous tensors.
	// The block has capacity 4 but only the first Length=3 rows are valid;
	// the trailing zero row must be ignored by the block attention path.
	blockK := cpu.NewTensor(tensor.Shape{4, 1}, []float32{1, 2, 3, 0})
	blockV := cpu.NewTensor(tensor.Shape{4, 1}, []float32{10, 20, 30, 0})
	view := kvcache.View{
		K:      blockK,
		V:      blockV,
		Start:  0,
		Length: 3,
		Device: tensor.CPU,
	}
	outBlocks := cpu.NewTensor(tensor.Shape{2, 1}, nil)
	if err := CausalAttentionBlocks(q, []kvcache.View{view}, outBlocks, numHeads, numKVHeads, headDim, startPos); err != nil {
		t.Fatalf("block attention failed: %v", err)
	}
	// Elementwise comparison with a small absolute tolerance for float32 math.
	for i := range outContig.DataFloat32() {
		if diff := math.Abs(float64(outContig.DataFloat32()[i] - outBlocks.DataFloat32()[i])); diff > 1e-5 {
			t.Fatalf("mismatch at %d: contiguous=%v blocks=%v", i, outContig.DataFloat32()[i], outBlocks.DataFloat32()[i])
		}
	}
}

// TestCausalAttentionPackedMatchesBlocks checks that the packed-block
// attention path (CausalAttentionPackedBlocks) agrees with the flat
// block-view path (CausalAttentionBlocks) on a GQA configuration
// (4 query heads sharing 2 KV heads, headDim 8).
//
// NOTE(review): with total = startPos + newTokens = 6 and blockSize = 8 the
// loop below builds exactly one block, so the multi-block boundary case is
// not exercised here — consider a second case with total > blockSize.
func TestCausalAttentionPackedMatchesBlocks(t *testing.T) {
	numHeads, numKVHeads, headDim := 4, 2, 8
	newTokens := 2
	startPos := 4
	blockSize := 8
	kvDim := numKVHeads * headDim
	// Deterministic non-constant query data in [0, 1).
	qData := make([]float32, newTokens*numHeads*headDim)
	for i := range qData {
		qData[i] = float32(i%7) / 7
	}
	q := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, qData)
	// total KV length includes past + current
	total := startPos + newTokens
	// Deterministic signed K/V data so softmax weights are non-trivial.
	kData := make([]float32, total*kvDim)
	vData := make([]float32, total*kvDim)
	for i := range kData {
		kData[i] = float32((i%17)-8) / 9
		vData[i] = float32((i%19)-9) / 10
	}
	views := make([]kvcache.View, 0, (total+blockSize-1)/blockSize)
	pviews := make([]kvcache.PackedView, 0, (total+blockSize-1)/blockSize)
	for start := 0; start < total; start += blockSize {
		// Clamp the final (possibly partial) block to the true KV length.
		length := blockSize
		if start+length > total {
			length = total - start
		}
		// Flat view layout: one row per token, heads interleaved within the
		// row ([blockSize, kvDim]); unused tail rows stay zero.
		kBlkData := make([]float32, blockSize*kvDim)
		vBlkData := make([]float32, blockSize*kvDim)
		copy(kBlkData, kData[start*kvDim:(start+length)*kvDim])
		copy(vBlkData, vData[start*kvDim:(start+length)*kvDim])
		kBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, kBlkData)
		vBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, vBlkData)
		views = append(views, kvcache.View{K: kBlk, V: vBlk, Start: start, Length: length, Device: tensor.CPU})
		// Packed layout groups by KV head first: dst index is
		// h*(blockSize*headDim) + ti*headDim, i.e. [numKVHeads][blockSize][headDim].
		pk := make([]float32, numKVHeads*blockSize*headDim)
		pv := make([]float32, numKVHeads*blockSize*headDim)
		for ti := 0; ti < length; ti++ {
			baseTok := (start + ti) * kvDim
			for h := 0; h < numKVHeads; h++ {
				srcBase := baseTok + h*headDim
				dstBase := h*(blockSize*headDim) + ti*headDim
				copy(pk[dstBase:dstBase+headDim], kData[srcBase:srcBase+headDim])
				copy(pv[dstBase:dstBase+headDim], vData[srcBase:srcBase+headDim])
			}
		}
		pviews = append(pviews, kvcache.PackedView{K: pk, V: pv, Start: start, Length: length, BlockSize: blockSize, HeadDim: headDim, NumKVHeads: numKVHeads})
	}
	outBlocks := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
	outPacked := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
	if err := CausalAttentionBlocks(q, views, outBlocks, numHeads, numKVHeads, headDim, startPos); err != nil {
		t.Fatalf("blocks attention failed: %v", err)
	}
	if err := CausalAttentionPackedBlocks(q, pviews, outPacked, numHeads, numKVHeads, headDim, startPos); err != nil {
		t.Fatalf("packed attention failed: %v", err)
	}
	// Elementwise comparison with a small absolute tolerance for float32 math.
	for i := range outBlocks.DataFloat32() {
		diff := math.Abs(float64(outBlocks.DataFloat32()[i] - outPacked.DataFloat32()[i]))
		if diff > 1e-5 {
			t.Fatalf("mismatch at %d: blocks=%v packed=%v", i, outBlocks.DataFloat32()[i], outPacked.DataFloat32()[i])
		}
	}
}