package nn

import (
	"fmt"
	"math"
	"sync"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/kvcache"
)

// CausalAttentionPackedBlocksBatch computes causal attention for a batch where each
// row in q/output is an independent sequence (decode-style batch).
//
// q:            [numTokens, numHeads*headDim] (float32, CPU)
// viewsByToken: per-token packed KV views (head-major blocks)
// output:       [numTokens, numHeads*headDim] (float32, CPU)
// queryPos:     per-token absolute query position (startPos for that token).
//
// This is optimized for CPU KV caches and uses SIMD-backed Dot/Axpy when available.
func CausalAttentionPackedBlocksBatch(
	q *cpu.Tensor,
	viewsByToken [][]kvcache.PackedView,
	output *cpu.Tensor,
	numHeads, numKVHeads, headDim int,
	queryPos []int,
) error {
	if q == nil || output == nil {
		return fmt.Errorf("nil tensor")
	}
	if numHeads <= 0 || numKVHeads <= 0 || headDim <= 0 {
		return fmt.Errorf("invalid heads/dim: numHeads=%d numKVHeads=%d headDim=%d", numHeads, numKVHeads, headDim)
	}
	if numHeads%numKVHeads != 0 {
		return fmt.Errorf("numHeads %d not divisible by numKVHeads %d", numHeads, numKVHeads)
	}
	qShape := q.Shape()
	outShape := output.Shape()
	if len(qShape) != 2 || len(outShape) != 2 {
		return fmt.Errorf("expected 2D tensors (q=%v out=%v)", qShape, outShape)
	}
	numTokens := qShape[0]
	qStride := qShape[1]
	if numTokens <= 0 {
		return nil
	}
	if outShape[0] != numTokens || outShape[1] != qStride {
		return fmt.Errorf("output shape mismatch: q=%v out=%v", qShape, outShape)
	}
	if qStride != numHeads*headDim {
		return fmt.Errorf("q stride mismatch: got %d want %d (numHeads*headDim)", qStride, numHeads*headDim)
	}
	if len(viewsByToken) != numTokens {
		return fmt.Errorf("viewsByToken len %d != numTokens %d", len(viewsByToken), numTokens)
	}
	if len(queryPos) != numTokens {
		return fmt.Errorf("queryPos len %d != numTokens %d", len(queryPos), numTokens)
	}

	qData := q.DataFloat32()
	outData := output.DataFloat32()
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	groupSize := numHeads / numKVHeads

	// One work item per (token, head) pair; take the serial path when
	// parallelism cannot help.
	workItems := numTokens * numHeads
	workers := cpu.MaxThreads()
	if workers < 2 || workItems < 2 {
		runPackedBatchWork(qData, outData, viewsByToken, queryPos, numTokens, numHeads, numKVHeads, headDim, groupSize, qStride, scale, 0, workItems)
		return nil
	}
	if workers > workItems {
		workers = workItems
	}
	// Ceil-division chunking: every worker gets a contiguous [start, end)
	// span; the last span is clamped to workItems.
	chunk := (workItems + workers - 1) / workers
	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		start := w * chunk
		end := start + chunk
		if end > workItems {
			end = workItems
		}
		go func(s, e int) {
			defer wg.Done()
			runPackedBatchWork(qData, outData, viewsByToken, queryPos, numTokens, numHeads, numKVHeads, headDim, groupSize, qStride, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}
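The worker fan-out above can be sketched in isolation. `splitWork` below is a hypothetical stand-alone helper (not part of this package) that reproduces the same ceil-division chunking and end clamping, which makes the partition invariants — every index covered exactly once, no empty goroutines when workers are clamped — easy to check by eye:

```go
package main

import "fmt"

// splitWork mirrors the chunking scheme used by
// CausalAttentionPackedBlocksBatch: workItems indices are split into
// ceil(workItems/workers)-sized contiguous spans, the last one clamped.
func splitWork(workItems, workers int) [][2]int {
	if workers > workItems {
		workers = workItems
	}
	chunk := (workItems + workers - 1) / workers
	var spans [][2]int
	for w := 0; w < workers; w++ {
		start := w * chunk
		end := start + chunk
		if end > workItems {
			end = workItems
		}
		if start < end {
			spans = append(spans, [2]int{start, end})
		}
	}
	return spans
}

func main() {
	// 10 items over 4 workers: three spans of 3 and one clamped span of 1.
	fmt.Println(splitWork(10, 4)) // [[0 3] [3 6] [6 9] [9 10]]
}
```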
// runPackedBatchWork processes work items in [workStart, workEnd), one item
// per (token, head) pair. Each head's attention is computed with a one-pass
// numerically stable softmax over the packed KV views: m tracks the running
// max score, l the running sum of exp(score-m), and outVec accumulates the
// exp(score-m)-weighted V rows.
func runPackedBatchWork(
	qData, outData []float32,
	viewsByToken [][]kvcache.PackedView,
	queryPos []int,
	numTokens, numHeads, numKVHeads, headDim, groupSize, qStride int,
	scale float32,
	workStart, workEnd int,
) {
	for idx := workStart; idx < workEnd; idx++ {
		tok := idx / numHeads
		head := idx - tok*numHeads
		if tok < 0 || tok >= numTokens {
			continue
		}
		qHeadOffset := head * headDim
		qBase := tok*qStride + qHeadOffset
		if qBase < 0 || qBase+headDim > len(qData) {
			continue
		}
		outBase := tok*qStride + qHeadOffset
		if outBase < 0 || outBase+headDim > len(outData) {
			continue
		}
		qPtr := &qData[qBase]
		outVec := outData[outBase : outBase+headDim]
		outPtr := &outData[outBase]
		clear(outVec)

		kvHead := head / groupSize // GQA: map query head to its shared KV head
		maxKeyPos := queryPos[tok] + 1 // causal mask: only keys at positions <= queryPos
		if maxKeyPos <= 0 {
			continue
		}

		m := float32(-math.MaxFloat32)
		l := float32(0)
		for _, pv := range viewsByToken[tok] {
			if pv.Length == 0 || pv.Start >= maxKeyPos {
				continue
			}
			if pv.HeadDim != headDim || pv.NumKVHeads != numKVHeads {
				continue
			}
			blkStride := pv.BlockSize * headDim
			headBase := kvHead * blkStride
			if blkStride <= 0 || headBase < 0 || headBase+blkStride > len(pv.K) || headBase+blkStride > len(pv.V) {
				continue
			}
			viewLimit := pv.Length
			if pv.Start+viewLimit > maxKeyPos {
				viewLimit = maxKeyPos - pv.Start
			}
			if viewLimit <= 0 {
				continue
			}
			kHead := pv.K[headBase : headBase+blkStride]
			vHead := pv.V[headBase : headBase+blkStride]
			for t := 0; t < viewLimit; t++ {
				kPtr := &kHead[t*headDim]
				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * scale
				vPtr := &vHead[t*headDim]
				if s > m {
					// New running max: rescale prior partial sums by exp(m - s).
					alpha := expf(m - s)
					if l != 0 {
						for i := 0; i < headDim; i++ {
							outVec[i] *= alpha
						}
						l *= alpha
					}
					m = s
					l += 1
					cpu.AxpyPtr(1, vPtr, outPtr, headDim)
					continue
				}
				w := expf(s - m)
				l += w
				cpu.AxpyPtr(w, vPtr, outPtr, headDim)
			}
		}
		// Normalize by the softmax denominator.
		if l != 0 {
			inv := 1 / l
			for i := 0; i < headDim; i++ {
				outVec[i] *= inv
			}
		}
	}
}