harness_cuda_test.go

//go:build cuda

package tests

import (
	"math"
	"math/rand"
	"testing"
	"unsafe"

	"makarna/pkg/backend/cpu"
	cpunn "makarna/pkg/backend/cpu/nn"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/quant"
	"makarna/pkg/tensor"
)

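// almostEq reports whether a and b agree within atol + rtol*|b|.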
func almostEq(a, b, atol, rtol float32) bool {
	d := a - b
	if d < 0 {
		d = -d
	}
	thr := atol + rtol*float32(math.Abs(float64(b)))
	return d <= thr
}

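// assertAllClose fails the test if got and want differ in length or any
// element falls outside the combined absolute/relative tolerance.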
func assertAllClose(t *testing.T, name string, got, want []float32, atol, rtol float32) {
	if len(got) != len(want) {
		t.Fatalf("%s len mismatch got=%d want=%d", name, len(got), len(want))
	}
	for i := range got {
		if !almostEq(got[i], want[i], atol, rtol) {
			t.Fatalf("%s mismatch at %d: got=%f want=%f (atol=%g rtol=%g)", name, i, got[i], want[i], atol, rtol)
		}
	}
}

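// TestHarness_CUDA_DequantMatchesCPU quantizes a random 256-element vector to
// each K-quant format, dequantizes it on the GPU, and compares the result
// against the CPU dequantizers.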
func TestHarness_CUDA_DequantMatchesCPU(t *testing.T) {
	gpu := 0
	if !cuda.Available() {
		t.Skip("cuda not available")
	}
	cases := []struct {
		name  string
		seed  int64
		scale float32
	}{
		{name: "small", seed: 10, scale: 0.01},
		{name: "medium", seed: 11, scale: 1.0},
		{name: "large", seed: 12, scale: 50.0},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			r := rand.New(rand.NewSource(tc.seed))
			inp := make([]float32, 256)
			for i := range inp {
				inp[i] = (r.Float32()*2 - 1) * tc.scale
			}
			q2 := quant.QuantizeQ2K(inp)
			q3 := quant.QuantizeQ3K(inp)
			q4 := quant.QuantizeQ4K(inp)
			q6 := quant.QuantizeQ6K(inp)
			q8 := quant.QuantizeQ8K(inp)
			ref2 := make([]float32, 256)
			ref3 := make([]float32, 256)
			ref4 := make([]float32, 256)
			ref6 := make([]float32, 256)
			ref8 := make([]float32, 256)
			tensor.DequantizeQ2_K((*tensor.BlockQ2_K)(unsafe.Pointer(&q2[0])), ref2)
			tensor.DequantizeQ3_K((*tensor.BlockQ3_K)(unsafe.Pointer(&q3[0])), ref3)
			tensor.DequantizeQ4_K((*tensor.BlockQ4_K)(unsafe.Pointer(&q4[0])), ref4)
			tensor.DequantizeQ6_K((*tensor.BlockQ6_K)(unsafe.Pointer(&q6[0])), ref6)
			tensor.DequantizeQ8_K((*tensor.BlockQ8_K)(unsafe.Pointer(&q8[0])), ref8)
			// Allocate output
			out, err := cuda.NewTensor(tensor.Shape{256}, tensor.Float32, gpu)
			if err != nil {
				t.Fatalf("new out: %v", err)
			}
			defer out.Free()
			// Q8
			devQ8, err := cuda.UploadQ8K(q8, 1, gpu)
			if err != nil {
				t.Fatalf("upload q8: %v", err)
			}
			defer cuda.FreeDevicePtr(devQ8)
			if err := cuda.DequantQ8K(devQ8, out.Data().(unsafe.Pointer), 1, gpu); err != nil {
				t.Fatalf("dequant q8: %v", err)
			}
			h8 := make([]float32, 256)
			if err := out.CopyToHost(h8); err != nil {
				t.Fatalf("copy q8: %v", err)
			}
			assertAllClose(t, "q8k", h8, ref8, 1e-3, 1e-3)
			// Q4
			devQ4, err := cuda.UploadQ4K(q4, 1, gpu)
			if err != nil {
				t.Fatalf("upload q4: %v", err)
			}
			defer cuda.FreeDevicePtr(devQ4)
			if err := cuda.DequantQ4K(devQ4, out.Data().(unsafe.Pointer), 1, gpu); err != nil {
				t.Fatalf("dequant q4: %v", err)
			}
			h4 := make([]float32, 256)
			if err := out.CopyToHost(h4); err != nil {
				t.Fatalf("copy q4: %v", err)
			}
			assertAllClose(t, "q4k", h4, ref4, 1e-2, 1e-2)
			// Q6
			devQ6, err := cuda.UploadQ6K(q6, 1, gpu)
			if err != nil {
				t.Fatalf("upload q6: %v", err)
			}
			defer cuda.FreeDevicePtr(devQ6)
			if err := cuda.DequantQ6K(devQ6, out.Data().(unsafe.Pointer), 1, gpu); err != nil {
				t.Fatalf("dequant q6: %v", err)
			}
			h6 := make([]float32, 256)
			if err := out.CopyToHost(h6); err != nil {
				t.Fatalf("copy q6: %v", err)
			}
			assertAllClose(t, "q6k", h6, ref6, 1e-2, 1e-2)
			// Q3
			devQ3, err := cuda.UploadQ3K(q3, 1, gpu)
			if err != nil {
				t.Fatalf("upload q3: %v", err)
			}
			defer cuda.FreeDevicePtr(devQ3)
			if err := cuda.DequantQ3K(devQ3, out.Data().(unsafe.Pointer), 1, gpu); err != nil {
				t.Fatalf("dequant q3: %v", err)
			}
			h3 := make([]float32, 256)
			if err := out.CopyToHost(h3); err != nil {
				t.Fatalf("copy q3: %v", err)
			}
			assertAllClose(t, "q3k", h3, ref3, 2e-2, 2e-2)
			// Q2
			devQ2, err := cuda.UploadQ2K(q2, 1, gpu)
			if err != nil {
				t.Fatalf("upload q2: %v", err)
			}
			defer cuda.FreeDevicePtr(devQ2)
			if err := cuda.DequantQ2K(devQ2, out.Data().(unsafe.Pointer), 1, gpu); err != nil {
				t.Fatalf("dequant q2: %v", err)
			}
			h2 := make([]float32, 256)
			if err := out.CopyToHost(h2); err != nil {
				t.Fatalf("copy q2: %v", err)
			}
			assertAllClose(t, "q2k", h2, ref2, 5e-2, 5e-2)
		})
	}
}

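// TestHarness_CUDA_FusedMatMulMatchesCPUReference computes C = A @ B^T with B
// quantized to each K-quant format and checks the fused CUDA matmul kernels
// against a float32 CPU reference.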
func TestHarness_CUDA_FusedMatMulMatchesCPUReference(t *testing.T) {
	gpu := 0
	if !cuda.Available() {
		t.Skip("cuda not available")
	}
	// Keep small M,N but K must be 256-multiple for K-quants
	M, K, N := 3, 256, 4
	r := rand.New(rand.NewSource(999))
	// CPU inputs
	Ahost := make([]float32, M*K)
	Bhost := make([]float32, N*K)
	for i := range Ahost {
		Ahost[i] = r.Float32()*2 - 1
	}
	for i := range Bhost {
		Bhost[i] = r.Float32()*2 - 1
	}
	// CPU reference: C = A @ B^T
	ref := make([]float32, M*N)
	for m := 0; m < M; m++ {
		for n := 0; n < N; n++ {
			var s float32
			for k := 0; k < K; k++ {
				s += Ahost[m*K+k] * Bhost[n*K+k]
			}
			ref[m*N+n] = s
		}
	}
	// Upload A to GPU
	Adev, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float32, gpu)
	if err != nil {
		t.Fatalf("Adev: %v", err)
	}
	defer Adev.Free()
	if err := Adev.CopyFrom(Ahost); err != nil {
		t.Fatalf("copy A: %v", err)
	}
	Cdev, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		t.Fatalf("Cdev: %v", err)
	}
	defer Cdev.Free()
	t.Run("q8k", func(t *testing.T) {
		q := quant.QuantizeQ8K(Bhost)
		devB, err := cuda.UploadQ8K(q, N*(K/256), gpu)
		if err != nil {
			t.Fatalf("upload: %v", err)
		}
		defer cuda.FreeDevicePtr(devB)
		if err := cuda.MatMulQ8K(Adev.Data().(unsafe.Pointer), devB, Cdev.Data().(unsafe.Pointer), M, K, N, gpu); err != nil {
			t.Fatalf("matmul: %v", err)
		}
		h := make([]float32, M*N)
		_ = cuda.Synchronize(gpu)
		if err := Cdev.CopyToHost(h); err != nil {
			t.Fatalf("copy: %v", err)
		}
		assertAllClose(t, "matmul q8k", h, ref, 5e-1, 5e-2)
	})
	t.Run("q4k", func(t *testing.T) {
		q := quant.QuantizeQ4K(Bhost)
		devB, err := cuda.UploadQ4K(q, N*(K/256), gpu)
		if err != nil {
			t.Fatalf("upload: %v", err)
		}
		defer cuda.FreeDevicePtr(devB)
		if err := cuda.MatMulQ4K(Adev.Data().(unsafe.Pointer), devB, Cdev.Data().(unsafe.Pointer), M, K, N, gpu); err != nil {
			t.Fatalf("matmul: %v", err)
		}
		h := make([]float32, M*N)
		_ = cuda.Synchronize(gpu)
		if err := Cdev.CopyToHost(h); err != nil {
			t.Fatalf("copy: %v", err)
		}
		assertAllClose(t, "matmul q4k", h, ref, 2.0, 1e-1)
	})
	// Q2/Q3/Q6 run as separate subtests against the same CPU reference,
	// with looser tolerances for the coarser quantizations.
  225. t.Run("q2k", func(t *testing.T) {
  226. q := quant.QuantizeQ2K(Bhost)
  227. devB, err := cuda.UploadQ2K(q, N*(K/256), gpu)
  228. if err != nil {
  229. t.Fatalf("upload: %v", err)
  230. }
  231. defer cuda.FreeDevicePtr(devB)
  232. if err := cuda.MatMulQ2K(Adev.Data().(unsafe.Pointer), devB, Cdev.Data().(unsafe.Pointer), M, K, N, gpu); err != nil {
  233. t.Fatalf("matmul: %v", err)
  234. }
  235. h := make([]float32, M*N)
  236. _ = cuda.Synchronize(gpu)
  237. if err := Cdev.CopyToHost(h); err != nil {
  238. t.Fatalf("copy: %v", err)
  239. }
  240. assertAllClose(t, "matmul q2k", h, ref, 3.0, 2e-1)
  241. })
  242. t.Run("q3k", func(t *testing.T) {
  243. q := quant.QuantizeQ3K(Bhost)
  244. devB, err := cuda.UploadQ3K(q, N*(K/256), gpu)
  245. if err != nil {
  246. t.Fatalf("upload: %v", err)
  247. }
  248. defer cuda.FreeDevicePtr(devB)
  249. if err := cuda.MatMulQ3K(Adev.Data().(unsafe.Pointer), devB, Cdev.Data().(unsafe.Pointer), M, K, N, gpu); err != nil {
  250. t.Fatalf("matmul: %v", err)
  251. }
  252. h := make([]float32, M*N)
  253. _ = cuda.Synchronize(gpu)
  254. if err := Cdev.CopyToHost(h); err != nil {
  255. t.Fatalf("copy: %v", err)
  256. }
  257. assertAllClose(t, "matmul q3k", h, ref, 2.5, 2e-1)
  258. })
  259. t.Run("q6k", func(t *testing.T) {
  260. q := quant.QuantizeQ6K(Bhost)
  261. devB, err := cuda.UploadQ6K(q, N*(K/256), gpu)
  262. if err != nil {
  263. t.Fatalf("upload: %v", err)
  264. }
  265. defer cuda.FreeDevicePtr(devB)
  266. if err := cuda.MatMulQ6K(Adev.Data().(unsafe.Pointer), devB, Cdev.Data().(unsafe.Pointer), M, K, N, gpu); err != nil {
  267. t.Fatalf("matmul: %v", err)
  268. }
  269. h := make([]float32, M*N)
  270. _ = cuda.Synchronize(gpu)
  271. if err := Cdev.CopyToHost(h); err != nil {
  272. t.Fatalf("copy: %v", err)
  273. }
  274. assertAllClose(t, "matmul q6k", h, ref, 1.0, 1e-1)
  275. })
  276. }
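// TestHarness_CUDA_NNOpsMatchCPU compares the CUDA RMSNorm, RoPE, Softmax, and
// causal attention kernels against the CPU implementations on small random inputs.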
func TestHarness_CUDA_NNOpsMatchCPU(t *testing.T) {
	gpu := 0
	if !cuda.Available() {
		t.Skip("cuda not available")
	}
	seqLen := 4
	headDim := 8
	numHeads := 2
	numKVHeads := 1
	totalDim := numHeads * headDim
	r := rand.New(rand.NewSource(2025))
	Q := make([]float32, seqLen*totalDim)
	K := make([]float32, seqLen*(numKVHeads*headDim))
	V := make([]float32, seqLen*(numKVHeads*headDim))
	W := make([]float32, totalDim)
	for i := range Q {
		Q[i] = r.Float32()*2 - 1
	}
	for i := range K {
		K[i] = r.Float32()*2 - 1
	}
	for i := range V {
		V[i] = r.Float32()*2 - 1
	}
	for i := range W {
		W[i] = r.Float32()*2 - 1
	}
	// CPU reference
	qCPU := cpu.NewTensor(tensor.Shape{seqLen, totalDim}, append([]float32(nil), Q...))
	kCPU := cpu.NewTensor(tensor.Shape{seqLen, numKVHeads * headDim}, append([]float32(nil), K...))
	vCPU := cpu.NewTensor(tensor.Shape{seqLen, numKVHeads * headDim}, append([]float32(nil), V...))
	outCPU := cpu.NewTensor(tensor.Shape{seqLen, totalDim}, nil)
	if err := cpunn.CausalAttention(qCPU, kCPU, vCPU, outCPU, numHeads, numKVHeads, headDim); err != nil {
		t.Fatalf("cpu attention: %v", err)
	}
	// CUDA tensors
	qGPU, _ := cuda.NewTensor(tensor.Shape{seqLen, totalDim}, tensor.Float32, gpu)
	kGPU, _ := cuda.NewTensor(tensor.Shape{seqLen, numKVHeads * headDim}, tensor.Float32, gpu)
	vGPU, _ := cuda.NewTensor(tensor.Shape{seqLen, numKVHeads * headDim}, tensor.Float32, gpu)
	outGPU, _ := cuda.NewTensor(tensor.Shape{seqLen, totalDim}, tensor.Float32, gpu)
	wGPU, _ := cuda.NewTensor(tensor.Shape{totalDim}, tensor.Float32, gpu)
	qAttGPU, _ := cuda.NewTensor(tensor.Shape{seqLen, totalDim}, tensor.Float32, gpu)
	defer qGPU.Free()
	defer kGPU.Free()
	defer vGPU.Free()
	defer outGPU.Free()
	defer wGPU.Free()
	defer qAttGPU.Free()
	_ = qGPU.CopyFrom(Q)
	_ = kGPU.CopyFrom(K)
	_ = vGPU.CopyFrom(V)
	_ = wGPU.CopyFrom(W)
	_ = qAttGPU.CopyFrom(Q)
	// RMSNorm CPU vs CUDA
	// Apply RMSNorm on a copy of Q
	qCPU2 := cpu.NewTensor(tensor.Shape{seqLen, totalDim}, append([]float32(nil), Q...))
	wCPU2 := cpu.NewTensor(tensor.Shape{totalDim}, append([]float32(nil), W...))
	if err := cpunn.RMSNorm(qCPU2, wCPU2, 1e-5); err != nil {
		t.Fatalf("cpu rmsnorm: %v", err)
	}
	if err := cuda.RMSNorm(qGPU.Data().(unsafe.Pointer), wGPU.Data().(unsafe.Pointer), seqLen, totalDim, 1e-5, gpu); err != nil {
		t.Fatalf("cuda rmsnorm: %v", err)
	}
	qR := make([]float32, seqLen*totalDim)
	_ = qGPU.CopyToHost(qR)
	assertAllClose(t, "rmsnorm", qR, qCPU2.DataFloat32(), 5e-3, 5e-3)
	// RoPE CPU vs CUDA
	pos := make([]int32, seqLen)
	posCPU := make([]int, seqLen)
	for i := range pos {
		pos[i] = int32(i)
		posCPU[i] = i
	}
	posDev, err := cuda.AllocAndCopyInt32(pos, gpu)
	if err != nil {
		t.Fatalf("alloc pos: %v", err)
	}
	defer cuda.FreeDevicePtr(posDev)
	qCPU3 := cpu.NewTensor(tensor.Shape{seqLen, totalDim}, append([]float32(nil), Q...))
	if err := cpunn.RoPE(qCPU3, posCPU, headDim, 10000); err != nil {
		t.Fatalf("cpu rope: %v", err)
	}
	qGPU2, _ := cuda.NewTensor(tensor.Shape{seqLen, totalDim}, tensor.Float32, gpu)
	defer qGPU2.Free()
	_ = qGPU2.CopyFrom(Q)
	if err := cuda.RoPE(qGPU2.Data().(unsafe.Pointer), posDev, seqLen, numHeads, headDim, 10000, gpu); err != nil {
		t.Fatalf("cuda rope: %v", err)
	}
	qr := make([]float32, seqLen*totalDim)
	_ = qGPU2.CopyToHost(qr)
	assertAllClose(t, "rope", qr, qCPU3.DataFloat32(), 2e-2, 2e-2)
	// Softmax CPU vs CUDA on one row
	rowCPU := cpu.NewTensor(tensor.Shape{totalDim}, append([]float32(nil), Q[:totalDim]...))
	if err := cpunn.Softmax(rowCPU); err != nil {
		t.Fatalf("cpu softmax: %v", err)
	}
	rowGPU, _ := cuda.NewTensor(tensor.Shape{1, totalDim}, tensor.Float32, gpu)
	defer rowGPU.Free()
	_ = rowGPU.CopyFrom(Q[:totalDim])
	if err := cuda.Softmax(rowGPU.Data().(unsafe.Pointer), 1, totalDim, gpu); err != nil {
		t.Fatalf("cuda softmax: %v", err)
	}
	rowOut := make([]float32, totalDim)
	_ = rowGPU.CopyToHost(rowOut)
	assertAllClose(t, "softmax", rowOut, rowCPU.DataFloat32(), 2e-3, 2e-3)
	// Attention CPU vs CUDA
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	if err := cuda.Attention(qAttGPU.Data().(unsafe.Pointer), kGPU.Data().(unsafe.Pointer), vGPU.Data().(unsafe.Pointer), outGPU.Data().(unsafe.Pointer), seqLen, seqLen, numHeads, numKVHeads, headDim, scale, 0, gpu); err != nil {
		t.Fatalf("cuda attention: %v", err)
	}
	outH := make([]float32, seqLen*totalDim)
	_ = outGPU.CopyToHost(outH)
	assertAllClose(t, "attention", outH, outCPU.DataFloat32(), 5e-2, 5e-2)
}

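// TestHarness_CUDA_PagedAttentionBatchMatchesSingle checks that the batched
// paged-attention kernel produces the same output as running the
// single-sequence paged-attention kernel once per token.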
func TestHarness_CUDA_PagedAttentionBatchMatchesSingle(t *testing.T) {
	gpu := 0
	if !cuda.Available() {
		t.Skip("cuda not available")
	}
	blockSize := 4
	kvLen0 := 5
	kvLen1 := 6
	headDim := 8
	numHeads := 2
	numKVHeads := 1
	kvStride := numKVHeads * headDim
	scale := float32(1.0 / math.Sqrt(float64(headDim)))
	// Decode-style: one token per sequence.
	numTokens := 2
	qGPU, err := cuda.NewTensor(tensor.Shape{numTokens, numHeads * headDim}, tensor.Float32, gpu)
	if err != nil {
		t.Fatalf("new q: %v", err)
	}
	defer qGPU.Free()
	outBatchGPU, err := cuda.NewTensor(tensor.Shape{numTokens, numHeads * headDim}, tensor.Float32, gpu)
	if err != nil {
		t.Fatalf("new out batch: %v", err)
	}
	defer outBatchGPU.Free()
	// Per-sequence outputs.
	out0GPU, _ := cuda.NewTensor(tensor.Shape{1, numHeads * headDim}, tensor.Float32, gpu)
	out1GPU, _ := cuda.NewTensor(tensor.Shape{1, numHeads * headDim}, tensor.Float32, gpu)
	defer out0GPU.Free()
	defer out1GPU.Free()
	r := rand.New(rand.NewSource(777))
	qHost := make([]float32, numTokens*numHeads*headDim)
	for i := range qHost {
		qHost[i] = r.Float32()*2 - 1
	}
	if err := qGPU.CopyFrom(qHost); err != nil {
		t.Fatalf("copy q: %v", err)
	}
	// Build paged K/V blocks for each sequence.
	makeSeqBlocks := func(kvLen int) ([]*cuda.Tensor, []*cuda.Tensor, []uintptr, []uintptr) {
		nBlocks := (kvLen + blockSize - 1) / blockSize
		kBlocks := make([]*cuda.Tensor, nBlocks)
		vBlocks := make([]*cuda.Tensor, nBlocks)
		kPtrs := make([]uintptr, nBlocks)
		vPtrs := make([]uintptr, nBlocks)
		for b := 0; b < nBlocks; b++ {
			kT, err := cuda.NewTensor(tensor.Shape{blockSize, kvStride}, tensor.Float32, gpu)
			if err != nil {
				t.Fatalf("new k block: %v", err)
			}
			vT, err := cuda.NewTensor(tensor.Shape{blockSize, kvStride}, tensor.Float32, gpu)
			if err != nil {
				t.Fatalf("new v block: %v", err)
			}
			kBlocks[b] = kT
			vBlocks[b] = vT
			kPtrs[b] = uintptr(kT.Data().(unsafe.Pointer))
			vPtrs[b] = uintptr(vT.Data().(unsafe.Pointer))
			kHost := make([]float32, blockSize*kvStride)
			vHost := make([]float32, blockSize*kvStride)
			for i := range kHost {
				kHost[i] = r.Float32()*2 - 1
				vHost[i] = r.Float32()*2 - 1
			}
			_ = kT.CopyFrom(kHost)
			_ = vT.CopyFrom(vHost)
		}
		return kBlocks, vBlocks, kPtrs, vPtrs
	}
	kBlocks0, vBlocks0, kPtrs0, vPtrs0 := makeSeqBlocks(kvLen0)
	kBlocks1, vBlocks1, kPtrs1, vPtrs1 := makeSeqBlocks(kvLen1)
	defer func() {
		for i := range kBlocks0 {
			kBlocks0[i].Free()
			vBlocks0[i].Free()
		}
		for i := range kBlocks1 {
			kBlocks1[i].Free()
			vBlocks1[i].Free()
		}
	}()
	// Reference: run single-seq paged attention for each token.
	kDev0, err := cuda.AllocAndCopyPtrTable(kPtrs0, gpu)
	if err != nil {
		t.Fatalf("alloc k ptrs0: %v", err)
	}
	defer cuda.FreeDevicePtr(kDev0)
	vDev0, err := cuda.AllocAndCopyPtrTable(vPtrs0, gpu)
	if err != nil {
		t.Fatalf("alloc v ptrs0: %v", err)
	}
	defer cuda.FreeDevicePtr(vDev0)
	kDev1, err := cuda.AllocAndCopyPtrTable(kPtrs1, gpu)
	if err != nil {
		t.Fatalf("alloc k ptrs1: %v", err)
	}
	defer cuda.FreeDevicePtr(kDev1)
	vDev1, err := cuda.AllocAndCopyPtrTable(vPtrs1, gpu)
	if err != nil {
		t.Fatalf("alloc v ptrs1: %v", err)
	}
	defer cuda.FreeDevicePtr(vDev1)
	q0View, _ := qGPU.ViewAt(tensor.Shape{1, numHeads * headDim}, 0)
	q1View, _ := qGPU.ViewAt(tensor.Shape{1, numHeads * headDim}, uintptr(numHeads*headDim*4))
	if err := cuda.PagedAttention(
		q0View.Data().(unsafe.Pointer),
		kDev0, vDev0,
		out0GPU.Data().(unsafe.Pointer),
		1, kvLen0,
		numHeads, numKVHeads, headDim,
		blockSize,
		scale, kvLen0-1,
		gpu,
	); err != nil {
		t.Fatalf("paged attention 0: %v", err)
	}
	if err := cuda.PagedAttention(
		q1View.Data().(unsafe.Pointer),
		kDev1, vDev1,
		out1GPU.Data().(unsafe.Pointer),
		1, kvLen1,
		numHeads, numKVHeads, headDim,
		blockSize,
		scale, kvLen1-1,
		gpu,
	); err != nil {
		t.Fatalf("paged attention 1: %v", err)
	}
	// Batched: flatten block pointer tables.
	flatKPtrs := append(append([]uintptr(nil), kPtrs0...), kPtrs1...)
	flatVPtrs := append(append([]uintptr(nil), vPtrs0...), vPtrs1...)
	kFlatDev, err := cuda.AllocAndCopyPtrTable(flatKPtrs, gpu)
	if err != nil {
		t.Fatalf("alloc flat k: %v", err)
	}
	defer cuda.FreeDevicePtr(kFlatDev)
	vFlatDev, err := cuda.AllocAndCopyPtrTable(flatVPtrs, gpu)
	if err != nil {
		t.Fatalf("alloc flat v: %v", err)
	}
	defer cuda.FreeDevicePtr(vFlatDev)
	blockOffsets := []int32{0, int32(len(kPtrs0))}
	kvLens := []int32{int32(kvLen0), int32(kvLen1)}
	queryPos := []int32{int32(kvLen0 - 1), int32(kvLen1 - 1)}
	maxKvLen := kvLen0
	if kvLen1 > maxKvLen {
		maxKvLen = kvLen1
	}
	offDev, err := cuda.AllocAndCopyInt32(blockOffsets, gpu)
	if err != nil {
		t.Fatalf("alloc offsets: %v", err)
	}
	defer cuda.FreeDevicePtr(offDev)
	kvDev, err := cuda.AllocAndCopyInt32(kvLens, gpu)
	if err != nil {
		t.Fatalf("alloc kv lens: %v", err)
	}
	defer cuda.FreeDevicePtr(kvDev)
	qposDev, err := cuda.AllocAndCopyInt32(queryPos, gpu)
	if err != nil {
		t.Fatalf("alloc qpos: %v", err)
	}
	defer cuda.FreeDevicePtr(qposDev)
	if err := cuda.PagedAttentionBatch(
		qGPU.Data().(unsafe.Pointer),
		kFlatDev,
		vFlatDev,
		offDev,
		kvDev,
		qposDev,
		outBatchGPU.Data().(unsafe.Pointer),
		numTokens,
		numHeads, numKVHeads, headDim,
		blockSize,
		scale,
		maxKvLen,
		gpu,
	); err != nil {
		t.Fatalf("paged attention batch: %v", err)
	}
	outBatchHost := make([]float32, numTokens*numHeads*headDim)
	if err := outBatchGPU.CopyToHost(outBatchHost); err != nil {
		t.Fatalf("copy out batch: %v", err)
	}
	out0Host := make([]float32, numHeads*headDim)
	out1Host := make([]float32, numHeads*headDim)
	_ = out0GPU.CopyToHost(out0Host)
	_ = out1GPU.CopyToHost(out1Host)
	assertAllClose(t, "paged_attention_batch_tok0", outBatchHost[:numHeads*headDim], out0Host, 2e-2, 2e-2)
	assertAllClose(t, "paged_attention_batch_tok1", outBatchHost[numHeads*headDim:], out1Host, 2e-2, 2e-2)
}