package nn

import (
	"math"
	"math/rand"
	"testing"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)

// TestCausalAttentionPackedBlocksBatch_MatchesPerToken checks that the batched
// packed-block attention kernel produces the same result as running the
// single-row kernel independently for every query token.
func TestCausalAttentionPackedBlocksBatch_MatchesPerToken(t *testing.T) {
	const (
		numTokens  = 7
		numHeads   = 8
		numKVHeads = 2
		headDim    = 16
		blockSize  = 4
	)
	qStride := numHeads * headDim

	// Each token attends causally up to its own position (inclusive).
	queryPos := []int{0, 3, 5, 1, 7, 2, 9}
	if len(queryPos) != numTokens {
		t.Fatalf("test bug: queryPos len %d != numTokens %d", len(queryPos), numTokens)
	}

	r := rand.New(rand.NewSource(1234))

	// Build per-token packed KV views: kvLen = pos+1 entries, chunked into
	// fixed-size blocks where the final block may be partial. K and V are
	// filled with interleaved random draws in [-1, 1).
	viewsByToken := make([][]kvcache.PackedView, numTokens)
	for ti := 0; ti < numTokens; ti++ {
		kvLen := queryPos[ti] + 1
		if kvLen <= 0 {
			t.Fatalf("kvLen must be > 0, got %d", kvLen)
		}
		blockCount := (kvLen + blockSize - 1) / blockSize
		tokenViews := make([]kvcache.PackedView, 0, blockCount)
		for blk := 0; blk < blockCount; blk++ {
			start := blk * blockSize
			length := blockSize
			if rem := kvLen - start; rem < length {
				length = rem
			}
			if length <= 0 {
				break
			}
			blkStride := blockSize * headDim
			k := make([]float32, numKVHeads*blkStride)
			v := make([]float32, numKVHeads*blkStride)
			for i := range k {
				k[i] = r.Float32()*2 - 1
				v[i] = r.Float32()*2 - 1
			}
			tokenViews = append(tokenViews, kvcache.PackedView{
				K:          k,
				V:          v,
				Start:      start,
				Length:     length,
				BlockSize:  blockSize,
				HeadDim:    headDim,
				NumKVHeads: numKVHeads,
			})
		}
		viewsByToken[ti] = tokenViews
	}

	// Random query rows, one per token.
	q := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
	qData := q.DataFloat32()
	for i := range qData {
		qData[i] = r.Float32()*2 - 1
	}

	// Run the batched kernel over all tokens at once.
	outBatch := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
	if err := CausalAttentionPackedBlocksBatch(q, viewsByToken, outBatch, numHeads, numKVHeads, headDim, queryPos); err != nil {
		t.Fatalf("CausalAttentionPackedBlocksBatch: %v", err)
	}

	// Reference: invoke the single-row kernel per token, writing directly
	// into the corresponding slice of the reference output buffer.
	outRef := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
	refData := outRef.DataFloat32()
	for ti := 0; ti < numTokens; ti++ {
		rowStart := ti * qStride
		qRow := cpu.NewTensor(tensor.Shape{1, qStride}, qData[rowStart:rowStart+qStride])
		outRow := cpu.NewTensor(tensor.Shape{1, qStride}, refData[rowStart:rowStart+qStride])
		if err := CausalAttentionPackedBlocks(qRow, viewsByToken[ti], outRow, numHeads, numKVHeads, headDim, queryPos[ti]); err != nil {
			t.Fatalf("CausalAttentionPackedBlocks tok=%d: %v", ti, err)
		}
	}

	// Element-wise comparison within a small absolute tolerance.
	got := outBatch.DataFloat32()
	want := outRef.DataFloat32()
	if len(got) != len(want) {
		t.Fatalf("length mismatch: got %d want %d", len(got), len(want))
	}
	const tol = 1e-4
	for i := range got {
		if diff := math.Abs(float64(got[i] - want[i])); diff > tol {
			t.Fatalf("mismatch at %d: got=%g want=%g diff=%g", i, got[i], want[i], diff)
		}
	}
}