package nn

import (
	"math"
	"math/rand"
	"testing"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)
- func TestCausalAttentionPackedBlocksBatch_MatchesPerToken(t *testing.T) {
- numTokens := 7
- numHeads := 8
- numKVHeads := 2
- headDim := 16
- blockSize := 4
- qStride := numHeads * headDim
- queryPos := []int{0, 3, 5, 1, 7, 2, 9}
- if len(queryPos) != numTokens {
- t.Fatalf("test bug: queryPos len %d != numTokens %d", len(queryPos), numTokens)
- }
- rng := rand.New(rand.NewSource(1234))
- viewsByToken := make([][]kvcache.PackedView, numTokens)
- for tok := 0; tok < numTokens; tok++ {
- kvLen := queryPos[tok] + 1
- if kvLen <= 0 {
- t.Fatalf("kvLen must be > 0, got %d", kvLen)
- }
- numBlocks := (kvLen + blockSize - 1) / blockSize
- views := make([]kvcache.PackedView, 0, numBlocks)
- for b := 0; b < numBlocks; b++ {
- start := b * blockSize
- length := blockSize
- if start+length > kvLen {
- length = kvLen - start
- }
- if length <= 0 {
- break
- }
- blkStride := blockSize * headDim
- k := make([]float32, numKVHeads*blkStride)
- v := make([]float32, numKVHeads*blkStride)
- for i := range k {
- k[i] = rng.Float32()*2 - 1
- v[i] = rng.Float32()*2 - 1
- }
- views = append(views, kvcache.PackedView{
- K: k,
- V: v,
- Start: start,
- Length: length,
- BlockSize: blockSize,
- HeadDim: headDim,
- NumKVHeads: numKVHeads,
- })
- }
- viewsByToken[tok] = views
- }
- q := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
- qData := q.DataFloat32()
- for i := range qData {
- qData[i] = rng.Float32()*2 - 1
- }
- outBatch := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
- if err := CausalAttentionPackedBlocksBatch(q, viewsByToken, outBatch, numHeads, numKVHeads, headDim, queryPos); err != nil {
- t.Fatalf("CausalAttentionPackedBlocksBatch: %v", err)
- }
- outRef := cpu.NewTensor(tensor.Shape{numTokens, qStride}, nil)
- outRefData := outRef.DataFloat32()
- for tok := 0; tok < numTokens; tok++ {
- qRow := cpu.NewTensor(tensor.Shape{1, qStride}, qData[tok*qStride:(tok+1)*qStride])
- outRow := cpu.NewTensor(tensor.Shape{1, qStride}, outRefData[tok*qStride:(tok+1)*qStride])
- if err := CausalAttentionPackedBlocks(qRow, viewsByToken[tok], outRow, numHeads, numKVHeads, headDim, queryPos[tok]); err != nil {
- t.Fatalf("CausalAttentionPackedBlocks tok=%d: %v", tok, err)
- }
- }
- got := outBatch.DataFloat32()
- want := outRef.DataFloat32()
- if len(got) != len(want) {
- t.Fatalf("length mismatch: got %d want %d", len(got), len(want))
- }
- const tol = 1e-4
- for i := range got {
- if diff := math.Abs(float64(got[i] - want[i])); diff > tol {
- t.Fatalf("mismatch at %d: got=%g want=%g diff=%g", i, got[i], want[i], diff)
- }
- }
- }