| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108 |
- package nn
import (
	"fmt"
	"math"
	"sync"

	"makarna/pkg/backend/cpu"
)
- func CausalAttentionCachedKV(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDimK, headDimV, startPos int) error {
- newTokens := q.Shape()[0]
- totalSeqLen := k.Shape()[0]
- qData := q.DataFloat32()
- kData := k.DataFloat32()
- vData := v.DataFloat32()
- outData := output.DataFloat32()
- scale := 1.0 / math.Sqrt(float64(headDimK))
- groupSize := numHeads / numKVHeads
- workers := cpu.MaxThreads()
- if workers < 2 || numHeads < 2 {
- runCausalCachedHeadsKV(qData, kData, vData, outData, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos, scale, 0, numHeads)
- return nil
- }
- chunk := (numHeads + workers - 1) / workers
- var wg sync.WaitGroup
- for start := 0; start < numHeads; start += chunk {
- end := start + chunk
- if end > numHeads {
- end = numHeads
- }
- wg.Add(1)
- go func(s, e int) {
- defer wg.Done()
- runCausalCachedHeadsKV(qData, kData, vData, outData, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos, scale, s, e)
- }(start, end)
- }
- wg.Wait()
- return nil
- }
- func runCausalCachedHeadsKV(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDimK, headDimV, groupSize, newTokens, totalSeqLen, startPos int, scale float64, hStart, hEnd int) {
- strideQ := numHeads * headDimK
- strideK := numKVHeads * headDimK
- strideV := numKVHeads * headDimV
- strideOut := numHeads * headDimV
- for h := hStart; h < hEnd; h++ {
- qHeadOffset := h * headDimK
- outHeadOffset := h * headDimV
- kvHead := h / groupSize
- kHeadOffset := kvHead * headDimK
- vHeadOffset := kvHead * headDimV
- for qi := 0; qi < newTokens; qi++ {
- maxKeyPos := startPos + qi + 1
- if maxKeyPos > totalSeqLen {
- maxKeyPos = totalSeqLen
- }
- qBase := qi*strideQ + qHeadOffset
- qPtr := &qData[qBase]
- outBase := qi*strideOut + outHeadOffset
- outVec := outData[outBase : outBase+headDimV]
- outPtr := &outData[outBase]
- clear(outVec)
- m := float32(-math.MaxFloat32)
- l := float32(0)
- for ti := 0; ti < maxKeyPos; ti++ {
- kBase := ti*strideK + kHeadOffset
- kPtr := &kData[kBase]
- s := cpu.DotFloat32Ptr(qPtr, kPtr, headDimK) * float32(scale)
- vBase := ti*strideV + vHeadOffset
- vPtr := &vData[vBase]
- if s > m {
- alpha := expf(m - s)
- if l != 0 {
- for i := 0; i < headDimV; i++ {
- outVec[i] *= alpha
- }
- l *= alpha
- }
- m = s
- l += 1
- cpu.AxpyPtr(1, vPtr, outPtr, headDimV)
- continue
- }
- w := expf(s - m)
- l += w
- cpu.AxpyPtr(w, vPtr, outPtr, headDimV)
- }
- if l != 0 {
- inv := 1 / l
- for i := 0; i < headDimV; i++ {
- outVec[i] *= inv
- }
- }
- }
- }
- }
|