package nn

import (
	"fmt"
	"math"
	"sync"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
)

// CausalAttentionPackedBlocksBatch computes causal attention for a batch where each
// row in q/output is an independent sequence (decode-style batch).
//
// q: [numTokens, numHeads*headDim] (float32, CPU)
// viewsByToken: per-token packed KV views (head-major blocks)
// output: [numTokens, numHeads*headDim] (float32, CPU)
// queryPos: per-token absolute query position (startPos for that token).
//
// This is optimized for CPU KV caches and uses SIMD-backed Dot/Axpy when available.
func CausalAttentionPackedBlocksBatch(
	q *cpu.Tensor,
	viewsByToken [][]kvcache.PackedView,
	output *cpu.Tensor,
	numHeads, numKVHeads, headDim int,
	queryPos []int,
) error {
	if q == nil || output == nil {
		return fmt.Errorf("nil tensor")
	}
	if numHeads <= 0 || numKVHeads <= 0 || headDim <= 0 {
		return fmt.Errorf("invalid heads/dim: numHeads=%d numKVHeads=%d headDim=%d", numHeads, numKVHeads, headDim)
	}
	if numHeads%numKVHeads != 0 {
		return fmt.Errorf("numHeads %d not divisible by numKVHeads %d", numHeads, numKVHeads)
	}
	qShape := q.Shape()
	outShape := output.Shape()
	if len(qShape) != 2 || len(outShape) != 2 {
		return fmt.Errorf("expected 2D tensors (q=%v out=%v)", qShape, outShape)
	}
	numTokens := qShape[0]
	qStride := qShape[1]
	if numTokens <= 0 {
		return nil
	}
	if outShape[0] != numTokens || outShape[1] != qStride {
		return fmt.Errorf("output shape mismatch: q=%v out=%v", qShape, outShape)
	}
	if qStride != numHeads*headDim {
		return fmt.Errorf("q stride mismatch: got %d want %d (numHeads*headDim)", qStride, numHeads*headDim)
	}
	if len(viewsByToken) != numTokens {
		return fmt.Errorf("viewsByToken len %d != numTokens %d", len(viewsByToken), numTokens)
	}
	if len(queryPos) != numTokens {
		return fmt.Errorf("queryPos len %d != numTokens %d", len(queryPos), numTokens)
	}

	qData := q.DataFloat32()
	outData := output.DataFloat32()
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	groupSize := numHeads / numKVHeads

	// One work item per (token, head) pair; shard items across up to MaxThreads workers.
	workItems := numTokens * numHeads
	workers := cpu.MaxThreads()
	if workers < 2 || workItems < 2 {
		runPackedBatchWork(qData, outData, viewsByToken, queryPos, numTokens, numHeads, numKVHeads, headDim, groupSize, qStride, scale, 0, workItems)
		return nil
	}
	if workers > workItems {
		workers = workItems
	}
	chunk := (workItems + workers - 1) / workers

	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		start := w * chunk
		end := start + chunk
		if end > workItems {
			end = workItems
		}
		go func(s, e int) {
			defer wg.Done()
			runPackedBatchWork(qData, outData, viewsByToken, queryPos, numTokens, numHeads, numKVHeads, headDim, groupSize, qStride, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}

// runPackedBatchWork processes work items [workStart, workEnd), where each item is a
// (token, head) pair, using a streaming (online) softmax over the token's packed KV views.
func runPackedBatchWork(
	qData, outData []float32,
	viewsByToken [][]kvcache.PackedView,
	queryPos []int,
	numTokens, numHeads, numKVHeads, headDim, groupSize, qStride int,
	scale float32,
	workStart, workEnd int,
) {
	for idx := workStart; idx < workEnd; idx++ {
		tok := idx / numHeads
		head := idx - tok*numHeads
		if tok < 0 || tok >= numTokens {
			continue
		}
		qHeadOffset := head * headDim
		qBase := tok*qStride + qHeadOffset
		if qBase < 0 || qBase+headDim > len(qData) {
			continue
		}
		outBase := tok*qStride + qHeadOffset
		if outBase < 0 || outBase+headDim > len(outData) {
			continue
		}
		qPtr := &qData[qBase]
		outVec := outData[outBase : outBase+headDim]
		outPtr := &outData[outBase]
		clear(outVec)

		// GQA: query heads within a group share one KV head.
		kvHead := head / groupSize
		// Causal bound: only key positions below queryPos[tok]+1 are attended to.
		maxKeyPos := queryPos[tok] + 1
		if maxKeyPos <= 0 {
			continue
		}

		// Online-softmax state: m is the running max score, l the running sum of
		// exp(score-m), and outVec accumulates exp(score-m) * V.
		m := float32(-math.MaxFloat32)
		l := float32(0)

		for _, pv := range viewsByToken[tok] {
			if pv.Length == 0 || pv.Start >= maxKeyPos {
				continue
			}
			if pv.HeadDim != headDim || pv.NumKVHeads != numKVHeads {
				continue
			}
			blkStride := pv.BlockSize * headDim
			headBase := kvHead * blkStride
			if blkStride <= 0 || headBase < 0 || headBase+blkStride > len(pv.K) || headBase+blkStride > len(pv.V) {
				continue
			}
			// Clip the view so it never reaches past the causal bound.
			viewLimit := pv.Length
			if pv.Start+viewLimit > maxKeyPos {
				viewLimit = maxKeyPos - pv.Start
			}
			if viewLimit <= 0 {
				continue
			}
			kHead := pv.K[headBase : headBase+blkStride]
			vHead := pv.V[headBase : headBase+blkStride]
			for t := 0; t < viewLimit; t++ {
				kPtr := &kHead[t*headDim]
				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * scale
				vPtr := &vHead[t*headDim]
				if s > m {
					// New running max: rescale the accumulator and normalizer
					// from the old max to the new one, then add this value row
					// with weight exp(s-s) = 1.
					alpha := expf(m - s)
					if l != 0 {
						for i := 0; i < headDim; i++ {
							outVec[i] *= alpha
						}
						l *= alpha
					}
					m = s
					l += 1
					cpu.AxpyPtr(1, vPtr, outPtr, headDim)
					continue
				}
				// Score does not exceed the running max: accumulate with weight exp(s-m).
				w := expf(s - m)
				l += w
				cpu.AxpyPtr(w, vPtr, outPtr, headDim)
			}
		}
		// outVec now holds sum_t exp(s_t-m)*V_t and l holds sum_t exp(s_t-m);
		// dividing by l yields the softmax-weighted sum.
		if l != 0 {
			inv := 1 / l
			for i := 0; i < headDim; i++ {
				outVec[i] *= inv
			}
		}
	}
}
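
// naiveAttentionHeadRef is a minimal reference sketch, not part of the original package
// API: the function name and its flat-slice layout (q of length headDim, k/v of length
// numKeys*headDim, row-major per key position) are assumptions made for illustration.
// It computes single-head attention with a plain two-pass softmax and can serve as a
// test oracle for the streaming softmax used in runPackedBatchWork above.
func naiveAttentionHeadRef(q, k, v, out []float32, headDim, numKeys int, scale float32) {
	if headDim <= 0 || numKeys <= 0 {
		return
	}
	// Pass 1: scaled dot-product scores and their maximum (for numerical stability).
	scores := make([]float32, numKeys)
	maxScore := float32(math.Inf(-1))
	for t := 0; t < numKeys; t++ {
		var s float32
		for i := 0; i < headDim; i++ {
			s += q[i] * k[t*headDim+i]
		}
		s *= scale
		scores[t] = s
		if s > maxScore {
			maxScore = s
		}
	}
	// Pass 2: normalize to softmax weights and accumulate the weighted sum of V rows.
	var denom float32
	for t := 0; t < numKeys; t++ {
		scores[t] = float32(math.Exp(float64(scores[t] - maxScore)))
		denom += scores[t]
	}
	clear(out[:headDim])
	for t := 0; t < numKeys; t++ {
		w := scores[t] / denom
		for i := 0; i < headDim; i++ {
			out[i] += w * v[t*headDim+i]
		}
	}
}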