linear_cuda.go

//go:build cuda

// Package compute provides device-agnostic computation dispatching.
package compute

import (
	"fmt"
	"unsafe"

	"makarna/pkg/backend/cpu"
	"makarna/pkg/backend/cpu/matmul"
	"makarna/pkg/backend/cuda"
	"makarna/pkg/backend/device"
	"makarna/pkg/profile"
	"makarna/pkg/tensor"
)

// Linear performs a linear layer: output = input @ weight.T
// Automatically dispatches to CPU or CUDA based on tensor placement.
// Uses GPU weight cache for persistent weight storage.
func Linear(ctx *Context, input, weight, output tensor.Tensor) error {
	useGPU := ctx != nil && ctx.IsGPU() && device.CUDAAvailable()
	if !useGPU {
		profile.Start("Linear/CPU")
		err := linearCPU(input, weight, output)
		profile.End("Linear/CPU")
		return err
	}
	switch weight.DType() {
	case tensor.Float32:
		profile.Start("Linear/F32")
		err := linearCUDAF32(ctx, input, weight, output)
		profile.End("Linear/F32")
		return err
	case tensor.Q8_K:
		profile.Start("Linear/Q8K")
		err := linearCUDAQ8K(ctx, input, weight, output)
		profile.End("Linear/Q8K")
		return err
	case tensor.Q5_K:
		profile.Start("Linear/Q5K")
		err := linearCUDAQ5K(ctx, input, weight, output)
		profile.End("Linear/Q5K")
		return err
	case tensor.Q4_K:
		profile.Start("Linear/Q4K")
		err := linearCUDAQ4K(ctx, input, weight, output)
		profile.End("Linear/Q4K")
		return err
	case tensor.Q2_K:
		profile.Start("Linear/Q2K")
		err := linearCUDAQ2K(ctx, input, weight, output)
		profile.End("Linear/Q2K")
		return err
	case tensor.Q3_K:
		profile.Start("Linear/Q3K")
		err := linearCUDAQ3K(ctx, input, weight, output)
		profile.End("Linear/Q3K")
		return err
	case tensor.Q6_K:
		profile.Start("Linear/Q6K")
		err := linearCUDAQ6K(ctx, input, weight, output)
		profile.End("Linear/Q6K")
		return err
	default:
		profile.Start("Linear/CPU")
		err := linearCPU(input, weight, output)
		profile.End("Linear/CPU")
		return err
	}
}

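// A minimal usage sketch for the dispatcher above. The names x, w, out and
// newContextForLayer are illustrative placeholders, not part of this package;
// real call sites construct *Context and the tensors elsewhere.
//
//	ctx := newContextForLayer(layerIdx, placement) // hypothetical helper
//	// x: [M, K] activations, w: [N, K] weights, out: [M, N] destination
//	if err := Linear(ctx, x, w, out); err != nil {
//		return err
//	}
//
// With a nil ctx, a CPU placement, or no CUDA device available, the call falls
// back to linearCPU; otherwise weight.DType() selects the CUDA kernel, and
// unrecognized dtypes also fall back to the CPU path.
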
// linearCPU runs the linear layer on the host. The input is converted to a CPU
// tensor if necessary; the weight and output must already be CPU tensors.
func linearCPU(input, weight, output tensor.Tensor) error {
	inCPU, ok := input.(*cpu.Tensor)
	if !ok {
		var err error
		inCPU, err = ToCPU(input)
		if err != nil {
			return fmt.Errorf("linear: failed to get CPU input: %w", err)
		}
	}
	wCPU, ok := weight.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear: weight must be CPU tensor for CPU path")
	}
	outCPU, ok := output.(*cpu.Tensor)
	if !ok {
		return fmt.Errorf("linear: output must be CPU tensor for CPU path")
	}
	return matmul.Linear(inCPU, wCPU, outCPU)
}

// linearCUDAF32 - F32 weights, cached on the GPU as F16 and multiplied with the F16 GEMM kernel
func linearCUDAF32(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU
	// Reuse the GPU input if already resident, otherwise upload it
	profile.Start("Linear/F32/input_upload")
	gpuInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/F32/input_upload")
	if err != nil {
		return err
	}
	// Get cached weight, uploading as F16 on first use
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_f16_%p", ctx.LayerIdx, weight.(*cpu.Tensor))
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/F32/weight_upload")
		cpuW := weight.(*cpu.Tensor)
		gpuWeight, err = cache.UploadF16(weightKey, cpuW)
		profile.End("Linear/F32/weight_upload")
		if err != nil {
			return fmt.Errorf("linear F32: cache weight: %w", err)
		}
	}
	// Allocate output
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear F32: alloc output: %w", err)
	}
	defer gpuOutput.Free()
	// Execute matmul using raw pointers
	profile.Start("Linear/F32/matmul_kernel")
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	if gpuInput.DType() != tensor.Float16 {
		// Cast F32 activations to F16 on the device so they match the F16 kernel
		profile.Start("Linear/F32/cast_fp16")
		f16In, err := cuda.NewTensor(tensor.Shape{M, K}, tensor.Float16, gpu)
		if err != nil {
			profile.End("Linear/F32/cast_fp16")
			profile.End("Linear/F32/matmul_kernel")
			return fmt.Errorf("linear F32: alloc f16 input: %w", err)
		}
		defer f16In.Free()
		if err := cuda.CastF32ToF16(aPtr, f16In.Data().(unsafe.Pointer), M*K, gpu); err != nil {
			profile.End("Linear/F32/cast_fp16")
			profile.End("Linear/F32/matmul_kernel")
			return fmt.Errorf("linear F32: cast input f32->f16: %w", err)
		}
		aPtr = f16In.Data().(unsafe.Pointer)
		profile.End("Linear/F32/cast_fp16")
	}
	err = cuda.MatMulF16(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/F32/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear F32: matmul f16: %w", err)
	}
	// Copy back to CPU output
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear F32: copy D2H: %w", err)
		}
	}
	return nil
}

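// Note on the F32 path above: the weight is uploaded once through
// cache.UploadF16 and then reused via its "layer<idx>_w_f16_<ptr>" key, while
// the activations are re-uploaded (and cast to F16 when needed) on every call.
// For illustration only: with M=1 and K=4096, the per-call cast touches
// M*K = 4096 values, whereas the N*K-element weight transfer happens only on
// the first call for a given layer/tensor.
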
// linearCUDAQ5K - Q5_K weights with caching
func linearCUDAQ5K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU
	profile.Start("Linear/Q5K/input_upload")
	gpuInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q5K/input_upload")
	if err != nil {
		return err
	}
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor))
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q5K/weight_upload")
		cpuW := weight.(*cpu.Tensor)
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q5K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q5K: cache weight: %w", err)
		}
	}
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q5K: alloc output: %w", err)
	}
	defer gpuOutput.Free()
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	profile.Start("Linear/Q5K/matmul_kernel")
	err = cuda.MatMulQ5K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q5K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q5K: matmul: %w", err)
	}
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q5K: copy D2H: %w", err)
		}
	}
	return nil
}

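// The remaining quantized paths (Q8_K, Q4_K, Q2_K, Q3_K, Q6_K) mirror the Q5_K
// path above step for step and differ only in the fused kernel they call:
// ensure the activations are resident on the GPU, look up or upload the
// quantized weight through the weight cache, allocate an F32 [M, N] output,
// run the fused matmul kernel (cuda.MatMulQ8K, cuda.MatMulQ4K, ...), and copy
// the result back when the output tensor lives on the CPU.
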
// linearCUDAQ8K - Q8_K weights with caching
func linearCUDAQ8K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU
	// Get GPU input
	profile.Start("Linear/Q8K/input_upload")
	gpuInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q8K/input_upload")
	if err != nil {
		return err
	}
	// Get cached weight or upload
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor))
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q8K/weight_upload")
		cpuW := weight.(*cpu.Tensor)
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q8K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q8K: cache weight: %w", err)
		}
	}
	// Allocate output
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q8K: alloc output: %w", err)
	}
	defer gpuOutput.Free()
	// Execute fused matmul
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	profile.Start("Linear/Q8K/matmul_kernel")
	err = cuda.MatMulQ8K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q8K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q8K: matmul: %w", err)
	}
	// Copy back
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q8K: copy D2H: %w", err)
		}
	}
	return nil
}

// linearCUDAQ4K - Q4_K weights with caching
func linearCUDAQ4K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU
	// Get GPU input
	profile.Start("Linear/Q4K/input_upload")
	gpuInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q4K/input_upload")
	if err != nil {
		return err
	}
	// Get cached weight
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor))
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q4K/weight_upload")
		cpuW := weight.(*cpu.Tensor)
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q4K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q4K: cache weight: %w", err)
		}
	}
	// Allocate output
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q4K: alloc output: %w", err)
	}
	defer gpuOutput.Free()
	// Execute fused matmul
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	profile.Start("Linear/Q4K/matmul_kernel")
	err = cuda.MatMulQ4K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q4K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q4K: matmul: %w", err)
	}
	// Copy back
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q4K: copy D2H: %w", err)
		}
	}
	return nil
}

// linearCUDAQ2K - Q2_K weights with caching
func linearCUDAQ2K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU
	// Get GPU input
	profile.Start("Linear/Q2K/input_upload")
	gpuInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q2K/input_upload")
	if err != nil {
		return err
	}
	// Get cached weight
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor))
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q2K/weight_upload")
		cpuW := weight.(*cpu.Tensor)
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q2K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q2K: cache weight: %w", err)
		}
	}
	// Allocate output
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q2K: alloc output: %w", err)
	}
	defer gpuOutput.Free()
	// Execute fused matmul
	profile.Start("Linear/Q2K/matmul_kernel")
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	err = cuda.MatMulQ2K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q2K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q2K: matmul: %w", err)
	}
	// Copy back
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q2K: copy D2H: %w", err)
		}
	}
	return nil
}

// linearCUDAQ3K - Q3_K weights with caching
func linearCUDAQ3K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU
	// Get GPU input
	profile.Start("Linear/Q3K/input_upload")
	gpuInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q3K/input_upload")
	if err != nil {
		return err
	}
	// Get cached weight
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor))
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q3K/weight_upload")
		cpuW := weight.(*cpu.Tensor)
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q3K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q3K: cache weight: %w", err)
		}
	}
	// Allocate output
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q3K: alloc output: %w", err)
	}
	defer gpuOutput.Free()
	// Execute fused matmul
	profile.Start("Linear/Q3K/matmul_kernel")
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	err = cuda.MatMulQ3K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q3K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q3K: matmul: %w", err)
	}
	// Copy back
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q3K: copy D2H: %w", err)
		}
	}
	return nil
}

// linearCUDAQ6K - Q6_K weights with caching
func linearCUDAQ6K(ctx *Context, input, weight, output tensor.Tensor) error {
	inShape := input.Shape()
	wShape := weight.Shape()
	M, K, N := inShape[0], inShape[1], wShape[0]
	gpu := ctx.Placement().GPU
	// Get GPU input
	profile.Start("Linear/Q6K/input_upload")
	gpuInput, err := getOrUploadInput(input, gpu)
	profile.End("Linear/Q6K/input_upload")
	if err != nil {
		return err
	}
	// Get cached weight
	cache := GetWeightCache(gpu)
	weightKey := fmt.Sprintf("layer%d_w_%p", ctx.LayerIdx, weight.(*cpu.Tensor))
	gpuWeight, ok := cache.Get(weightKey)
	if !ok {
		profile.Start("Linear/Q6K/weight_upload")
		cpuW := weight.(*cpu.Tensor)
		gpuWeight, err = cache.Upload(weightKey, cpuW)
		profile.End("Linear/Q6K/weight_upload")
		if err != nil {
			return fmt.Errorf("linear Q6K: cache weight: %w", err)
		}
	}
	// Allocate output
	gpuOutput, err := cuda.NewTensor(tensor.Shape{M, N}, tensor.Float32, gpu)
	if err != nil {
		return fmt.Errorf("linear Q6K: alloc output: %w", err)
	}
	defer gpuOutput.Free()
	// Execute fused matmul
	profile.Start("Linear/Q6K/matmul_kernel")
	aPtr := gpuInput.Data().(unsafe.Pointer)
	cPtr := gpuOutput.Data().(unsafe.Pointer)
	err = cuda.MatMulQ6K(aPtr, gpuWeight, cPtr, M, K, N, gpu)
	profile.End("Linear/Q6K/matmul_kernel")
	if err != nil {
		return fmt.Errorf("linear Q6K: matmul: %w", err)
	}
	// Copy back
	if cpuOut, ok := output.(*cpu.Tensor); ok {
		if err := gpuOutput.CopyToHost(cpuOut.DataFloat32()); err != nil {
			return fmt.Errorf("linear Q6K: copy D2H: %w", err)
		}
	}
	return nil
}

// getOrUploadInput returns the input as a CUDA tensor on the given GPU: an
// input that already lives on the device is returned as-is, a CPU input is
// uploaded host-to-device.
func getOrUploadInput(input tensor.Tensor, gpu int) (*cuda.Tensor, error) {
	if cudaIn, ok := input.(*cuda.Tensor); ok {
		return cudaIn, nil
	}
	cpuIn, ok := input.(*cpu.Tensor)
	if !ok {
		return nil, fmt.Errorf("input must be CPU or CUDA tensor")
	}
	shape := input.Shape()
	gpuInput, err := cuda.NewTensor(shape, tensor.Float32, gpu)
	if err != nil {
		return nil, fmt.Errorf("alloc GPU input: %w", err)
	}
	if err := gpuInput.CopyFrom(cpuIn.DataFloat32()); err != nil {
		return nil, fmt.Errorf("copy input H2D: %w", err)
	}
	return gpuInput, nil
}

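// Note: for CPU inputs this helper allocates a fresh CUDA tensor that the
// callers in this file do not Free; releasing (or pooling) that allocation is
// assumed to happen elsewhere.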