attention_cached.go

package nn

import (
	"fmt"
	"math"
	"sort"
	"sync"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/kvcache"
	"makarna/pkg/tensor"
)

// useFastExp selects the fast Schraudolph-style approximation in expf.
var useFastExp = true

// float16BitsToFloat32 converts an IEEE-754 binary16 bit pattern to float32,
// handling subnormals, infinities, and NaN.
func float16BitsToFloat32(bits uint16) float32 {
	sign := uint32(bits&0x8000) << 16
	exp := int32((bits & 0x7C00) >> 10)
	mant := uint32(bits & 0x03FF)
	if exp == 0 {
		if mant == 0 {
			// Signed zero.
			return math.Float32frombits(sign)
		}
		// Subnormal: normalize the mantissa and adjust the exponent.
		for mant&0x0400 == 0 {
			mant <<= 1
			exp--
		}
		exp++
		mant &= 0x03FF
	} else if exp == 0x1F {
		if mant == 0 {
			return math.Float32frombits(sign | 0x7F800000) // +/-Inf
		}
		return math.Float32frombits(sign | 0x7FC00000) // NaN
	}
	exp += 127 - 15 // rebias from binary16 to binary32
	return math.Float32frombits(sign | (uint32(exp) << 23) | (mant << 13))
}

// bfloat16BitsToFloat32 widens a bfloat16 bit pattern, which is simply the
// top 16 bits of the corresponding float32.
func bfloat16BitsToFloat32(bits uint16) float32 {
	return math.Float32frombits(uint32(bits) << 16)
}
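
// float16ConversionExamples pairs well-known binary16 and bfloat16 encodings
// with the float32 values they decode to. Illustrative executable
// documentation only; it is not called by the attention paths.
func float16ConversionExamples() [][2]float32 {
	return [][2]float32{
		{float16BitsToFloat32(0x3C00), 1.0},          // exp=15, mant=0 -> 1.0
		{float16BitsToFloat32(0xC000), -2.0},         // sign bit set, exp=16 -> -2.0
		{float16BitsToFloat32(0x0001), 5.9604645e-8}, // smallest subnormal, 2^-24
		{bfloat16BitsToFloat32(0x3F80), 1.0},         // bfloat16 1.0 = top half of float32 1.0
	}
}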

// expf is a fast exponential for softmax weights. With useFastExp it uses a
// Schraudolph-style bit-manipulation approximation; otherwise it falls back
// to math.Exp.
func expf(x float32) float32 {
	if useFastExp {
		// Clamp to a range where the approximation is stable.
		// For softmax weights, very negative values underflow to ~0 anyway.
		if x < -20 {
			x = -20
		} else if x > 10 {
			x = 10
		}
		// Schraudolph-style fast exp: build the float32 bit pattern of
		// 2^(x/ln 2) directly. Good tradeoff for softmax weights; much
		// faster than math.Exp.
		const a = 12102203.0   // (1<<23)/ln(2)
		const b = 1065353216.0 // 127<<23, the float32 exponent bias
		return math.Float32frombits(uint32(float32(a)*x + float32(b)))
	}
	return float32(math.Exp(float64(x)))
}
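
// expfMaxRelError returns a rough analytic bound on expf's approximation
// error: the linear-mantissa trick overestimates 2^f by at most
// (1+f)/2^f - 1 at f = 1/ln2 - 1, about 6.1%. Illustrative only, not called
// by the hot paths; the softmax renormalization partly absorbs this error.
func expfMaxRelError() float64 {
	f := 1/math.Ln2 - 1
	return (1+f)/math.Pow(2, f) - 1 // ~= 0.0614
}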

// viewData is a KV view whose tensors have been extracted to host float32 slices.
type viewData struct {
	kData  []float32
	vData  []float32
	start  int
	length int
}

// CausalAttentionCached computes causal attention using cached K/V.
//
//	Q:      [newTokens, numHeads * headDim]     - queries for the new tokens only
//	K:      [totalSeqLen, numKVHeads * headDim] - full K history including current tokens
//	V:      [totalSeqLen, numKVHeads * headDim] - full V history including current tokens
//	Output: [newTokens, numHeads * headDim]
//
// startPos is the position of the first new token in the sequence.
func CausalAttentionCached(q, k, v, output *cpu.Tensor, numHeads, numKVHeads, headDim, startPos int) error {
	newTokens := q.Shape()[0]
	totalSeqLen := k.Shape()[0]
	qData := q.DataFloat32()
	kData := k.DataFloat32()
	vData := v.DataFloat32()
	outData := output.DataFloat32()
	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads
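	// GQA mapping: groupSize consecutive query heads share one KV head. For
	// example (illustrative numbers), numHeads=32 with numKVHeads=8 gives
	// groupSize=4, so query heads 0-3 all attend through kvHead 0.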
	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalCachedHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos, scale, 0, numHeads)
		return nil
	}
	// Shard heads across workers.
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalCachedHeads(qData, kData, vData, outData, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}
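
// naiveAttentionHeadRef is a two-pass reference for a single (head, query)
// pair, documenting the math the streaming loops below implement:
// out = sum_t softmax(q . k_t * scale)_t * v_t over the visible keys.
// Illustrative only; it is not called by the fast paths, and keys/values are
// plain slices so the sketch stays self-contained.
func naiveAttentionHeadRef(q []float32, keys, vals [][]float32, scale float32) []float32 {
	// Pass 1: scores and their maximum, for numerical stability.
	scores := make([]float32, len(keys))
	maxScore := float32(math.Inf(-1))
	for t, k := range keys {
		var s float32
		for i := range q {
			s += q[i] * k[i]
		}
		scores[t] = s * scale
		if scores[t] > maxScore {
			maxScore = scores[t]
		}
	}
	// Pass 2: weighted sum of values, then normalize.
	out := make([]float32, len(q))
	var sum float32
	for t := range scores {
		w := float32(math.Exp(float64(scores[t] - maxScore)))
		sum += w
		for i := range out {
			out[i] += w * vals[t][i]
		}
	}
	if sum != 0 {
		inv := 1 / sum
		for i := range out {
			out[i] *= inv
		}
	}
	return out
}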

// CausalAttentionPackedBlocks computes causal attention over packed KV views.
// The packed layout is head-major: [kvHead][tokenWithinBlock][headDim] as a
// flat slice. This avoids kvDim-strided traversal and is a fast CPU path.
func CausalAttentionPackedBlocks(
	q *cpu.Tensor,
	views []kvcache.PackedView,
	output *cpu.Tensor,
	numHeads, numKVHeads, headDim, startPos int,
) error {
	newTokens := q.Shape()[0]
	qData := q.DataFloat32()
	outData := output.DataFloat32()
	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads
	// Sort in place to guarantee increasing start positions.
	sort.Slice(views, func(i, j int) bool {
		return views[i].Start < views[j].Start
	})
	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalPackedHeads(qData, outData, views, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, 0, numHeads)
		return nil
	}
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalPackedHeads(qData, outData, views, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}
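
// packedIndex spells out the packed head-major indexing used by
// runCausalPackedHeads: the vector for token t of kvHead starts at
// kvHead*blockSize*headDim + t*headDim. Illustrative helper; the hot loop
// inlines this arithmetic.
func packedIndex(kvHead, blockSize, headDim, t int) int {
	return kvHead*blockSize*headDim + t*headDim
}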

func runCausalCachedHeads(qData, kData, vData, outData []float32, numHeads, numKVHeads, headDim, groupSize, newTokens, totalSeqLen, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim
	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize
		kvHeadOffset := kvHead * headDim
		for qi := 0; qi < newTokens; qi++ {
			// Causal mask: the query at absolute position startPos+qi may
			// attend to keys at positions below startPos+qi+1.
			maxKeyPos := startPos + qi + 1
			if maxKeyPos > totalSeqLen {
				maxKeyPos = totalSeqLen
			}
			qBase := qi*strideQ + qHeadOffset
			qPtr := &qData[qBase]
			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			outPtr := &outData[outBase]
			clear(outVec)
			m := float32(-math.MaxFloat32)
			l := float32(0)
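			// Streaming (online) softmax over the key positions: maintain the
			// running maximum m, the normalizer l = sum_j exp(s_j - m), and the
			// unnormalized accumulator outVec = sum_j exp(s_j - m) * v_j. When
			// a new maximum appears, l and outVec are rescaled by
			// exp(oldM - newM) so the invariant holds without a second pass
			// over the keys.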
			for ti := 0; ti < maxKeyPos; ti++ {
				kBase := ti*strideKV + kvHeadOffset
				kPtr := &kData[kBase]
				s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * float32(scale)
				vBase := ti*strideKV + kvHeadOffset
				vPtr := &vData[vBase]
				if s > m {
					// New running maximum: rescale accumulator and normalizer.
					alpha := expf(m - s)
					if l != 0 {
						for i := 0; i < headDim; i++ {
							outVec[i] *= alpha
						}
						l *= alpha
					}
					m = s
					l += 1
					cpu.AxpyPtr(1, vPtr, outPtr, headDim)
					continue
				}
				w := expf(s - m)
				l += w
				cpu.AxpyPtr(w, vPtr, outPtr, headDim)
			}
			if l != 0 {
				// Normalize by the softmax sum.
				inv := 1 / l
				for i := 0; i < headDim; i++ {
					outVec[i] *= inv
				}
			}
		}
	}
}

func runCausalPackedHeads(qData, outData []float32, views []kvcache.PackedView, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDim
	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize
		for qi := 0; qi < newTokens; qi++ {
			maxKeyPos := startPos + qi + 1
			qBase := qi*strideQ + qHeadOffset
			qPtr := &qData[qBase]
			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			outPtr := &outData[outBase]
			clear(outVec)
			// Same streaming softmax as runCausalCachedHeads; see the
			// invariant note there.
			m := float32(-math.MaxFloat32)
			l := float32(0)
			for _, pv := range views {
				// Skip empty views and views beyond the causal horizon.
				if pv.Length == 0 || pv.Start >= maxKeyPos {
					continue
				}
				// Skip views whose geometry does not match this call.
				if pv.HeadDim != headDim || pv.NumKVHeads != numKVHeads {
					continue
				}
				blkStride := pv.BlockSize * headDim
				headBase := kvHead * blkStride
				if headBase+blkStride > len(pv.K) || headBase+blkStride > len(pv.V) {
					continue
				}
				viewLimit := pv.Length
				if pv.Start+viewLimit > maxKeyPos {
					viewLimit = maxKeyPos - pv.Start
				}
				kHead := pv.K[headBase : headBase+blkStride]
				vHead := pv.V[headBase : headBase+blkStride]
				for t := 0; t < viewLimit; t++ {
					kPtr := &kHead[t*headDim]
					s := cpu.DotFloat32Ptr(qPtr, kPtr, headDim) * float32(scale)
					vPtr := &vHead[t*headDim]
					if s > m {
						alpha := expf(m - s)
						if l != 0 {
							for i := 0; i < headDim; i++ {
								outVec[i] *= alpha
							}
							l *= alpha
						}
						m = s
						l += 1
						cpu.AxpyPtr(1, vPtr, outPtr, headDim)
						continue
					}
					w := expf(s - m)
					l += w
					cpu.AxpyPtr(w, vPtr, outPtr, headDim)
				}
			}
			if l != 0 {
				inv := 1 / l
				for i := 0; i < headDim; i++ {
					outVec[i] *= inv
				}
			}
		}
	}
}

// CausalAttentionBlocks computes attention directly over KV block views without
// materializing a contiguous history tensor. startPos is the absolute position
// of the first new token (the current cache length before the append).
func CausalAttentionBlocks(
	q *cpu.Tensor,
	views []kvcache.View,
	output *cpu.Tensor,
	numHeads, numKVHeads, headDim, startPos int,
) error {
	newTokens := q.Shape()[0]
	qData := q.DataFloat32()
	outData := output.DataFloat32()
	scale := 1.0 / math.Sqrt(float64(headDim))
	groupSize := numHeads / numKVHeads
	// Pre-extract data from all views (handles both CPU and GPU tensors).
	viewsData := make([]viewData, len(views))
	for i, v := range views {
		if v.Length == 0 {
			continue
		}
		kData, err := tensorToFloat32(v.K)
		if err != nil {
			return fmt.Errorf("failed to get K data from view: %w", err)
		}
		vData, err := tensorToFloat32(v.V)
		if err != nil {
			return fmt.Errorf("failed to get V data from view: %w", err)
		}
		viewsData[i] = viewData{
			kData:  kData,
			vData:  vData,
			start:  v.Start,
			length: v.Length,
		}
	}
	// Sort to guarantee increasing start positions.
	sort.Slice(viewsData, func(i, j int) bool {
		return viewsData[i].start < viewsData[j].start
	})
	workers := cpu.MaxThreads()
	if workers < 2 || numHeads < 2 {
		runCausalBlockHeads(qData, outData, viewsData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, 0, numHeads)
		return nil
	}
	chunk := (numHeads + workers - 1) / workers
	var wg sync.WaitGroup
	for start := 0; start < numHeads; start += chunk {
		end := start + chunk
		if end > numHeads {
			end = numHeads
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			runCausalBlockHeads(qData, outData, viewsData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos, scale, s, e)
		}(start, end)
	}
	wg.Wait()
	return nil
}
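
// The block loops derive each key's absolute position from its view's Start,
// so they rely on the sorted views tiling the key prefix without gaps.
// viewsCoverPrefix is an illustrative sanity check of that assumption; it is
// a hypothetical helper and is not called anywhere.
func viewsCoverPrefix(views []kvcache.View, limit int) bool {
	sorted := make([]kvcache.View, len(views))
	copy(sorted, views)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i].Start < sorted[j].Start })
	next := 0
	for _, v := range sorted {
		if v.Length == 0 {
			continue
		}
		if v.Start != next {
			return false // gap or overlap before this view
		}
		next += v.Length
	}
	return next >= limit
}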

func runCausalBlockHeads(qData, outData []float32, viewsData []viewData, numHeads, numKVHeads, headDim, groupSize, newTokens, startPos int, scale float64, hStart, hEnd int) {
	strideQ := numHeads * headDim
	strideKV := numKVHeads * headDim
	for h := hStart; h < hEnd; h++ {
		qHeadOffset := h * headDim
		kvHead := h / groupSize
		kvHeadOffset := kvHead * headDim
		for qi := 0; qi < newTokens; qi++ {
			maxKeyPos := startPos + qi + 1
			qBase := qi*strideQ + qHeadOffset
			qVec := qData[qBase : qBase+headDim]
			outBase := qi*strideQ + qHeadOffset
			outVec := outData[outBase : outBase+headDim]
			clear(outVec)
			// Same streaming softmax as runCausalCachedHeads; see the
			// invariant note there.
			m := float32(-math.MaxFloat32)
			l := float32(0)
			for _, vd := range viewsData {
				if vd.start >= maxKeyPos || vd.length == 0 {
					continue
				}
				viewLimit := vd.length
				if vd.start+viewLimit > maxKeyPos {
					viewLimit = maxKeyPos - vd.start
				}
				for local := 0; local < viewLimit; local++ {
					kvIdx := local*strideKV + kvHeadOffset
					kVec := vd.kData[kvIdx : kvIdx+headDim]
					s := cpu.DotFloat32(qVec, kVec) * float32(scale)
					vVec := vd.vData[kvIdx : kvIdx+headDim]
					if s > m {
						alpha := expf(m - s)
						if l != 0 {
							for i := 0; i < headDim; i++ {
								outVec[i] *= alpha
							}
							l *= alpha
						}
						m = s
						l += 1
						cpu.Axpy(1, vVec, outVec)
						continue
					}
					w := expf(s - m)
					l += w
					cpu.Axpy(w, vVec, outVec)
				}
			}
			if l != 0 {
				inv := 1 / l
				for i := 0; i < headDim; i++ {
					outVec[i] *= inv
				}
			}
		}
	}
}

// tensorToFloat32 extracts float32 data from a tensor, handling both CPU and
// CUDA tensors. Float16 and bfloat16 CPU tensors are widened to float32.
func tensorToFloat32(t tensor.Tensor) ([]float32, error) {
	switch tt := t.(type) {
	case *cpu.Tensor:
		switch tt.DType() {
		case tensor.Float32:
			return tt.DataFloat32(), nil
		case tensor.Float16:
			in := tt.DataUint16()
			out := make([]float32, len(in))
			for i := range in {
				out[i] = float16BitsToFloat32(in[i])
			}
			return out, nil
		case tensor.BFloat16:
			in := tt.DataUint16()
			out := make([]float32, len(in))
			for i := range in {
				out[i] = bfloat16BitsToFloat32(in[i])
			}
			return out, nil
		default:
			return nil, fmt.Errorf("unsupported CPU tensor dtype: %v", tt.DType())
		}
	case *cuda.Tensor:
		// Materialize device data into a host buffer.
		data := make([]float32, tt.Shape().NumElements())
		if err := tt.CopyToHost(data); err != nil {
			return nil, err
		}
		return data, nil
	default:
		return nil, fmt.Errorf("unsupported tensor type: %T", t)
	}
}

func cpuDevice() tensor.DeviceType { return tensor.CPU }