| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- package nn
- import (
- "math"
- "testing"
- "makarna/pkg/backend/cpu"
- "makarna/pkg/kvcache"
- "makarna/pkg/tensor"
- )
- func TestCausalAttentionBlocksMatchesContiguous(t *testing.T) {
- numHeads, numKVHeads, headDim := 1, 1, 1
- startPos := 1
- // Two new tokens attending over three total tokens (one past + two new).
- q := cpu.NewTensor(tensor.Shape{2, 1}, []float32{0.5, 1.0})
- kAll := cpu.NewTensor(tensor.Shape{3, 1}, []float32{1, 2, 3})
- vAll := cpu.NewTensor(tensor.Shape{3, 1}, []float32{10, 20, 30})
- outContig := cpu.NewTensor(tensor.Shape{2, 1}, nil)
- if err := CausalAttentionCached(q, kAll, vAll, outContig, numHeads, numKVHeads, headDim, startPos); err != nil {
- t.Fatalf("contiguous attention failed: %v", err)
- }
- // Build block view equivalent to the contiguous tensors.
- blockK := cpu.NewTensor(tensor.Shape{4, 1}, []float32{1, 2, 3, 0})
- blockV := cpu.NewTensor(tensor.Shape{4, 1}, []float32{10, 20, 30, 0})
- view := kvcache.View{
- K: blockK,
- V: blockV,
- Start: 0,
- Length: 3,
- Device: tensor.CPU,
- }
- outBlocks := cpu.NewTensor(tensor.Shape{2, 1}, nil)
- if err := CausalAttentionBlocks(q, []kvcache.View{view}, outBlocks, numHeads, numKVHeads, headDim, startPos); err != nil {
- t.Fatalf("block attention failed: %v", err)
- }
- for i := range outContig.DataFloat32() {
- if diff := math.Abs(float64(outContig.DataFloat32()[i] - outBlocks.DataFloat32()[i])); diff > 1e-5 {
- t.Fatalf("mismatch at %d: contiguous=%v blocks=%v", i, outContig.DataFloat32()[i], outBlocks.DataFloat32()[i])
- }
- }
- }
- func TestCausalAttentionPackedMatchesBlocks(t *testing.T) {
- numHeads, numKVHeads, headDim := 4, 2, 8
- newTokens := 2
- startPos := 4
- blockSize := 8
- kvDim := numKVHeads * headDim
- qData := make([]float32, newTokens*numHeads*headDim)
- for i := range qData {
- qData[i] = float32(i%7) / 7
- }
- q := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, qData)
- // total KV length includes past + current
- total := startPos + newTokens
- kData := make([]float32, total*kvDim)
- vData := make([]float32, total*kvDim)
- for i := range kData {
- kData[i] = float32((i%17)-8) / 9
- vData[i] = float32((i%19)-9) / 10
- }
- views := make([]kvcache.View, 0, (total+blockSize-1)/blockSize)
- pviews := make([]kvcache.PackedView, 0, (total+blockSize-1)/blockSize)
- for start := 0; start < total; start += blockSize {
- length := blockSize
- if start+length > total {
- length = total - start
- }
- kBlkData := make([]float32, blockSize*kvDim)
- vBlkData := make([]float32, blockSize*kvDim)
- copy(kBlkData, kData[start*kvDim:(start+length)*kvDim])
- copy(vBlkData, vData[start*kvDim:(start+length)*kvDim])
- kBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, kBlkData)
- vBlk := cpu.NewTensor(tensor.Shape{blockSize, kvDim}, vBlkData)
- views = append(views, kvcache.View{K: kBlk, V: vBlk, Start: start, Length: length, Device: tensor.CPU})
- pk := make([]float32, numKVHeads*blockSize*headDim)
- pv := make([]float32, numKVHeads*blockSize*headDim)
- for ti := 0; ti < length; ti++ {
- baseTok := (start + ti) * kvDim
- for h := 0; h < numKVHeads; h++ {
- srcBase := baseTok + h*headDim
- dstBase := h*(blockSize*headDim) + ti*headDim
- copy(pk[dstBase:dstBase+headDim], kData[srcBase:srcBase+headDim])
- copy(pv[dstBase:dstBase+headDim], vData[srcBase:srcBase+headDim])
- }
- }
- pviews = append(pviews, kvcache.PackedView{K: pk, V: pv, Start: start, Length: length, BlockSize: blockSize, HeadDim: headDim, NumKVHeads: numKVHeads})
- }
- outBlocks := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
- outPacked := cpu.NewTensor(tensor.Shape{newTokens, numHeads * headDim}, nil)
- if err := CausalAttentionBlocks(q, views, outBlocks, numHeads, numKVHeads, headDim, startPos); err != nil {
- t.Fatalf("blocks attention failed: %v", err)
- }
- if err := CausalAttentionPackedBlocks(q, pviews, outPacked, numHeads, numKVHeads, headDim, startPos); err != nil {
- t.Fatalf("packed attention failed: %v", err)
- }
- for i := range outBlocks.DataFloat32() {
- diff := math.Abs(float64(outBlocks.DataFloat32()[i] - outPacked.DataFloat32()[i]))
- if diff > 1e-5 {
- t.Fatalf("mismatch at %d: blocks=%v packed=%v", i, outBlocks.DataFloat32()[i], outPacked.DataFloat32()[i])
- }
- }
- }
|