//go:build cuda

package cuda

/*
#cgo CFLAGS: -I${SRCDIR}
#cgo LDFLAGS: -L${SRCDIR}/../../..//build/cuda -Wl,-Bstatic -lmakarna_cuda -Wl,-Bdynamic
#cgo LDFLAGS: -L/usr/local/cuda/lib64 -lcudart -lstdc++ -lm
#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../..//build/cuda -Wl,-rpath,/usr/local/cuda/lib64
#include "kernels.h"
*/
import "C"

import (
	"errors"
	"fmt"
	"runtime"
	"time"
	"unsafe"

	"makarna/pkg/profile"
	"makarna/pkg/tensor"
)

func syncIfProfiling(gpu int) error {
	if !profile.Enabled() {
		return nil
	}
	return Synchronize(gpu)
}

// Ensure interface compliance.
var _ tensor.Tensor = (*Tensor)(nil)

// Storage holds the underlying GPU memory with reference counting.
// Multiple Tensors can share the same Storage (e.g., views, reshapes).
// Memory is freed only when all references are gone.
type Storage struct {
	ptr unsafe.Pointer
	gpu int
	// Note: We rely on Go's GC and SetFinalizer for ref counting.
	// Each Tensor that shares this storage keeps a reference to it.
	// When the last Tensor is GC'd, the Storage becomes unreachable,
	// and its finalizer frees the GPU memory.
}

// newStorage creates a new Storage and sets up its finalizer.
func newStorage(ptr unsafe.Pointer, gpu int) *Storage {
	s := &Storage{ptr: ptr, gpu: gpu}
	runtime.SetFinalizer(s, func(st *Storage) {
		_ = C.cuda_set_device(C.int(st.gpu))
		C.cuda_free(st.ptr)
	})
	return s
}

type Tensor struct {
	shape   tensor.Shape
	dtype   tensor.DType
	storage *Storage       // Shared storage with ref counting
	ptr     unsafe.Pointer // Pointer into storage (may be offset for slices)
	gpu     int
	// ownsStorage indicates whether this Tensor is responsible for explicitly
	// freeing the underlying CUDA allocation.
	// Views/reshapes must not free shared storage because they may outlive the
	// base tensor (e.g. scratch-buffer views).
	ownsStorage bool
}

// NewTensor allocates memory on the GPU.
func NewTensor(shape tensor.Shape, dtype tensor.DType, gpu int) (*Tensor, error) {
	if dtype != tensor.Float32 && dtype != tensor.Float16 && dtype != tensor.BFloat16 {
		return nil, errors.New("unsupported dtype on CUDA")
	}
	if gpu < 0 {
		gpu = 0
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, errors.New("failed to set cuda device")
	}
	size := shape.NumElements() * dtype.Size()
	ptr := C.cuda_malloc(C.size_t(size))
	if ptr == nil {
		return nil, errors.New("cuda malloc failed")
	}
	storage := newStorage(ptr, gpu)
	t := &Tensor{
		shape:       shape,
		dtype:       dtype,
		storage:     storage,
		ptr:         ptr,
		gpu:         gpu,
		ownsStorage: true,
	}
	return t, nil
}

func (t *Tensor) Shape() tensor.Shape {
	return t.shape
}

func (t *Tensor) DType() tensor.DType {
	return t.dtype
}

func (t *Tensor) Device() tensor.DeviceType {
	return tensor.CUDA
}

// GPU returns the device ordinal.
func (t *Tensor) GPU() int {
	return t.gpu
}

func (t *Tensor) Placement() tensor.DevicePlacement {
	return tensor.DevicePlacement{Type: tensor.CUDA, GPU: t.gpu}
}

func (t *Tensor) Data() interface{} {
	return t.ptr
}

// Free explicitly frees the GPU memory associated with the tensor.
// Use this for temporary tensors to avoid OOM due to delayed GC.
func (t *Tensor) Free() {
	if t == nil {
		return
	}
	// Only the allocating tensor should explicitly free the CUDA allocation.
	// Views/reshapes share storage and must not free it.
	if t.storage != nil && t.ownsStorage {
		// Clear the finalizer so it doesn't run later.
		runtime.SetFinalizer(t.storage, nil)
		_ = C.cuda_set_device(C.int(t.gpu))
		C.cuda_free(t.storage.ptr)
	}
	t.storage = nil
	t.ptr = nil
}
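
// The following is an illustrative usage sketch, not part of the original API:
// it shows the intended lifetime pattern for temporaries (allocate, use, Free
// explicitly) and that views share storage and are never freed themselves.
// The tensor.Shape composite literal is an assumption based on how Shape is
// indexed elsewhere in this file.
func exampleTemporaryTensor(gpu int) error {
	tmp, err := NewTensor(tensor.Shape{4, 8}, tensor.Float32, gpu)
	if err != nil {
		return err
	}
	// The allocating tensor owns the storage; free it explicitly so GPU memory
	// is released before the next GC cycle.
	defer tmp.Free()
	// A reshape is a view: it shares storage and must not be freed explicitly.
	if _, err := tmp.Reshape(tensor.Shape{8, 4}); err != nil {
		return err
	}
	host := make([]float32, tmp.Shape().NumElements())
	if err := tmp.CopyFrom(host); err != nil {
		return err
	}
	return tmp.CopyToHost(host)
}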

func (t *Tensor) Add(other tensor.Tensor) error {
	o, ok := other.(*Tensor)
	if !ok {
		return errors.New("other must be CUDA tensor")
	}
	if t.dtype != tensor.Float32 || o.dtype != tensor.Float32 {
		return errors.New("Add only supports Float32")
	}
	if t.shape.NumElements() != o.shape.NumElements() {
		return errors.New("shape mismatch")
	}
	// In-place add: t += o.
	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_add_f32((*C.float)(t.ptr), (*C.float)(o.ptr), C.size_t(t.shape.NumElements()))
	if ret != 0 {
		return errors.New("cuda add failed")
	}
	return nil
}

func PagedAttentionBatch(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_paged_attention_batch_f32(
		(*C.float)(Q),
		(**C.float)(kBlocksFlatDev),
		(**C.float)(vBlocksFlatDev),
		(*C.int)(blockOffsetsDev),
		(*C.int)(kvLensDev),
		(*C.int)(queryPosDev),
		(*C.float)(out),
		C.int(numTokens),
		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
		C.int(blockSize),
		C.float(scale),
		C.int(maxKvLen),
	)
	if ret != 0 {
		return errors.New("cuda paged attention batch failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func (t *Tensor) Mul(other tensor.Tensor) error {
	o, ok := other.(*Tensor)
	if !ok {
		return errors.New("other must be CUDA tensor")
	}
	if t.dtype != tensor.Float32 || o.dtype != tensor.Float32 {
		return errors.New("Mul only supports Float32")
	}
	if t.shape.NumElements() != o.shape.NumElements() {
		return errors.New("shape mismatch")
	}
	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_mul_f32((*C.float)(t.ptr), (*C.float)(o.ptr), C.size_t(t.shape.NumElements()))
	if ret != 0 {
		return errors.New("cuda mul failed")
	}
	return nil
}

func (t *Tensor) MatMul(other tensor.Tensor, out tensor.Tensor) error {
	B, ok := other.(*Tensor)
	if !ok {
		return errors.New("other must be CUDA tensor")
	}
	C_out, ok := out.(*Tensor)
	if !ok {
		return errors.New("out must be CUDA tensor")
	}
	if t.dtype != tensor.Float32 || B.dtype != tensor.Float32 || C_out.dtype != tensor.Float32 {
		return errors.New("MatMul only supports Float32")
	}
	if len(t.shape) != 2 || len(B.shape) != 2 || len(C_out.shape) != 2 {
		return errors.New("only 2D matmul")
	}
	M := t.shape[0]
	K := t.shape[1]
	// We use NT matmul (A @ B^T), so B is expected to be [N, K].
	N := B.shape[0]
	K2 := B.shape[1]
	if K != K2 {
		return fmt.Errorf("k dimension mismatch: A[%d,%d] vs B[%d,%d]", M, K, N, K2)
	}
	if C_out.shape[0] != M || C_out.shape[1] != N {
		return fmt.Errorf("out shape mismatch: expected [%d,%d], got [%d,%d]", M, N, C_out.shape[0], C_out.shape[1])
	}
	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f32_nt(
		(*C.float)(t.ptr),
		(*C.float)(B.ptr),
		(*C.float)(C_out.ptr),
		C.int(M), C.int(K), C.int(N),
	)
	if ret != 0 {
		return errors.New("cuda matmul failed")
	}
	return nil
}
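
// Illustrative sketch (not in the original file): the NT shape convention used
// by MatMul. For A of shape [M, K] and weights stored row-major as B[N, K],
// the output is [M, N] (C = A @ B^T). Shapes are arbitrary; the tensor.Shape
// literals are an assumption.
func exampleMatMulNT(gpu int) error {
	a, err := NewTensor(tensor.Shape{2, 16}, tensor.Float32, gpu) // A: [M=2, K=16]
	if err != nil {
		return err
	}
	defer a.Free()
	b, err := NewTensor(tensor.Shape{3, 16}, tensor.Float32, gpu) // B: [N=3, K=16]
	if err != nil {
		return err
	}
	defer b.Free()
	out, err := NewTensor(tensor.Shape{2, 3}, tensor.Float32, gpu) // C: [M=2, N=3]
	if err != nil {
		return err
	}
	defer out.Free()
	return a.MatMul(b, out)
}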

// Reshape creates a view (shared storage) with a new shape.
// The new tensor shares the same underlying Storage, so memory
// is only freed when all tensors sharing this storage are GC'd.
func (t *Tensor) Reshape(shape tensor.Shape) (tensor.Tensor, error) {
	if shape.NumElements() != t.shape.NumElements() {
		return nil, errors.New("num elements mismatch")
	}
	// Share the same storage - Go's GC handles ref counting for us.
	return &Tensor{
		shape:       shape,
		dtype:       t.dtype,
		storage:     t.storage, // Shared reference
		ptr:         t.ptr,
		gpu:         t.gpu,
		ownsStorage: false,
	}, nil
}

// ViewAt returns a view into the tensor starting at the given byte offset.
// The returned tensor shares storage and does not allocate.
func (t *Tensor) ViewAt(shape tensor.Shape, offsetBytes uintptr) (*Tensor, error) {
	if t == nil {
		return nil, errors.New("nil tensor")
	}
	if offsetBytes%uintptr(t.dtype.Size()) != 0 {
		return nil, fmt.Errorf("offset %d not aligned to dtype size %d", offsetBytes, t.dtype.Size())
	}
	newPtr := unsafe.Pointer(uintptr(t.ptr) + offsetBytes)
	return &Tensor{
		shape:       shape,
		dtype:       t.dtype,
		storage:     t.storage,
		ptr:         newPtr,
		gpu:         t.gpu,
		ownsStorage: false,
	}, nil
}
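
// Illustrative sketch (not in the original file): taking a zero-copy row view
// into a 2D tensor with ViewAt. The view shares storage, stays valid while
// either tensor is reachable, and must never be freed explicitly. Indexing
// Shape() like a slice and the tensor.Shape literal are assumptions.
func exampleRowView(t *Tensor, row int) (*Tensor, error) {
	cols := t.Shape()[1]
	offsetBytes := uintptr(row*cols) * uintptr(t.DType().Size())
	return t.ViewAt(tensor.Shape{1, cols}, offsetBytes)
}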

func (t *Tensor) View(shape tensor.Shape) (tensor.Tensor, error) {
	return t.Reshape(shape)
}

func (t *Tensor) ToDevice(device tensor.DeviceType) (tensor.Tensor, error) {
	if device == tensor.CUDA {
		return t, nil
	}
	// TODO: support CUDA -> CPU. This needs a CPU tensor allocation plus a
	// D2H memcpy. Importing makarna/pkg/backend/cpu from here would not be
	// circular (both packages depend only on tensor), but for now callers
	// should use a helper instead.
	if device == tensor.CPU {
		return nil, errors.New("ToDevice(CPU) not implemented here yet, use helper")
	}
	return nil, errors.New("unknown device")
}

func (t *Tensor) CopyFrom(data interface{}) error {
	if t.dtype != tensor.Float32 {
		return errors.New("CopyFrom only supports Float32")
	}
	// Expects host data as []float32.
	src, ok := data.([]float32)
	if !ok {
		return errors.New("data must be []float32")
	}
	size := len(src) * 4
	if size != t.shape.NumElements()*t.dtype.Size() {
		return errors.New("size mismatch")
	}
	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	start := time.Now()
	ret := C.cuda_memcpy_h2d(t.ptr, unsafe.Pointer(&src[0]), C.size_t(size))
	if ret != 0 {
		runtime.KeepAlive(src)
		runtime.KeepAlive(t)
		return errors.New("cuda memcpy failed")
	}
	profile.RecordTransfer("CopyFrom/H2D", profile.EventH2D, int64(size), time.Since(start), t.gpu)
	runtime.KeepAlive(src)
	runtime.KeepAlive(t)
	return nil
}

// CopyToHost copies the tensor contents back to host memory.
func (t *Tensor) CopyToHost(dst []float32) error {
	if t.dtype != tensor.Float32 {
		return errors.New("CopyToHost only supports Float32")
	}
	size := len(dst) * 4
	if size != t.shape.NumElements()*4 {
		return errors.New("size mismatch")
	}
	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	start := time.Now()
	ret := C.cuda_memcpy_d2h(unsafe.Pointer(&dst[0]), t.ptr, C.size_t(size))
	if ret != 0 {
		runtime.KeepAlive(dst)
		runtime.KeepAlive(t)
		return errors.New("cuda memcpy d2h failed")
	}
	profile.RecordTransfer("CopyToHost/D2H", profile.EventD2H, int64(size), time.Since(start), t.gpu)
	runtime.KeepAlive(dst)
	runtime.KeepAlive(t)
	return nil
}

func (t *Tensor) CopyToInt32(dst []int32) error {
	if t.dtype != tensor.Int32 {
		return errors.New("CopyToInt32 only supports Int32")
	}
	size := len(dst) * 4
	if size != t.shape.NumElements()*4 {
		return errors.New("size mismatch")
	}
	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_memcpy_d2h(unsafe.Pointer(&dst[0]), t.ptr, C.size_t(size))
	if ret != 0 {
		return errors.New("cuda memcpy d2h failed")
	}
	runtime.KeepAlive(dst)
	runtime.KeepAlive(t)
	return nil
}

// CopyPartialFrom copies a portion of host data to the tensor at a given offset.
// dstOffset: offset in float32 elements from the start of the tensor.
// src: host data to copy from.
func (t *Tensor) CopyPartialFrom(dstOffset int, src []float32) error {
	if t.dtype != tensor.Float32 {
		return errors.New("CopyPartialFrom only supports Float32")
	}
	if dstOffset+len(src) > t.shape.NumElements() {
		return errors.New("partial copy would exceed tensor bounds")
	}
	if len(src) == 0 {
		return nil
	}
	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	// Calculate the destination pointer with offset.
	dstPtr := unsafe.Pointer(uintptr(t.ptr) + uintptr(dstOffset*4))
	size := len(src) * 4
	start := time.Now()
	ret := C.cuda_memcpy_h2d(dstPtr, unsafe.Pointer(&src[0]), C.size_t(size))
	if ret != 0 {
		runtime.KeepAlive(src)
		runtime.KeepAlive(t)
		return errors.New("cuda memcpy partial failed")
	}
	profile.RecordTransfer("CopyPartialFrom/H2D", profile.EventH2D, int64(size), time.Since(start), t.gpu)
	runtime.KeepAlive(src)
	runtime.KeepAlive(t)
	return nil
}

// CopyPartialFromDevice copies a portion from another CUDA tensor into this tensor.
// Offsets and length are in elements of the tensor's dtype.
func (t *Tensor) CopyPartialFromDevice(dstOffset int, src *Tensor, srcOffset int, length int) error {
	if t.dtype != src.dtype {
		return errors.New("dtype mismatch")
	}
	if dstOffset+length > t.shape.NumElements() {
		return errors.New("dst offset/length exceed tensor bounds")
	}
	if srcOffset+length > src.shape.NumElements() {
		return errors.New("src offset/length exceed tensor bounds")
	}
	if length == 0 {
		return nil
	}
	if ret := C.cuda_set_device(C.int(t.gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	start := time.Now()
	eltSize := t.dtype.Size()
	dstPtr := unsafe.Pointer(uintptr(t.ptr) + uintptr(dstOffset*eltSize))
	srcPtr := unsafe.Pointer(uintptr(src.ptr) + uintptr(srcOffset*eltSize))
	size := C.size_t(length * eltSize)
	if ret := C.cuda_memcpy_d2d(dstPtr, srcPtr, size); ret != 0 {
		runtime.KeepAlive(src)
		runtime.KeepAlive(t)
		return errors.New("cuda memcpy d2d failed")
	}
	profile.RecordTransfer("CopyPartialFromDevice/D2D", profile.EventD2D, int64(length*eltSize), time.Since(start), t.gpu)
	runtime.KeepAlive(src)
	runtime.KeepAlive(t)
	return nil
}
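
// Illustrative sketch (not in the original file): device-to-device staging with
// CopyPartialFromDevice, copying `rows` rows of width `cols` between two
// row-major tensors. Offsets are passed in elements, matching the method above.
func exampleGatherRows(dst, src *Tensor, dstRow, srcRow, rows, cols int) error {
	return dst.CopyPartialFromDevice(dstRow*cols, src, srcRow*cols, rows*cols)
}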

func CastF32ToF16(srcF32, dstF16 unsafe.Pointer, n int, gpu int) error {
	if n <= 0 {
		return nil
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	if ret := C.cuda_cast_f32_to_f16((*C.float)(srcF32), (*C.ushort)(dstF16), C.int(n)); ret != 0 {
		return errors.New("cuda cast f32->f16 failed")
	}
	return nil
}

func PagedAttentionF32F16KV(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_paged_attention_f32_f16kv(
		(*C.float)(Q),
		(**C.ushort)(kBlocksDev),
		(**C.ushort)(vBlocksDev),
		(*C.float)(out),
		C.int(seqLen), C.int(kvLen),
		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
		C.int(blockSize),
		C.float(scale), C.int(startPos),
	)
	if ret != 0 {
		return errors.New("cuda paged attention f16kv failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func PagedAttentionBatchF32F16KV(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_paged_attention_batch_f32_f16kv(
		(*C.float)(Q),
		(**C.ushort)(kBlocksFlatDev),
		(**C.ushort)(vBlocksFlatDev),
		(*C.int)(blockOffsetsDev),
		(*C.int)(kvLensDev),
		(*C.int)(queryPosDev),
		(*C.float)(out),
		C.int(numTokens),
		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
		C.int(blockSize),
		C.float(scale),
		C.int(maxKvLen),
	)
	if ret != 0 {
		return errors.New("cuda paged attention batch f16kv failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// PagedAttentionRoPEF32F16KV runs paged attention with fused RoPE inside the kernel.
// Expects un-rotated Q and un-rotated K blocks; V blocks are unchanged.
func PagedAttentionRoPEF32F16KV(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, theta float32, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_paged_attention_rope_f32_f16kv(
		(*C.float)(Q),
		(**C.ushort)(kBlocksDev),
		(**C.ushort)(vBlocksDev),
		(*C.float)(out),
		C.int(seqLen), C.int(kvLen),
		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
		C.int(blockSize),
		C.float(scale), C.int(startPos),
		C.float(theta),
	)
	if ret != 0 {
		return errors.New("cuda paged attention rope f16kv failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// PagedAttentionBatchRoPEF32F16KV runs batched paged attention with fused RoPE inside the kernel.
// Expects un-rotated Q and un-rotated K blocks; V blocks are unchanged.
func PagedAttentionBatchRoPEF32F16KV(Q, kBlocksFlatDev, vBlocksFlatDev, blockOffsetsDev, kvLensDev, queryPosDev, out unsafe.Pointer, numTokens, numHeads, numKVHeads, headDim, blockSize int, scale float32, maxKvLen int, theta float32, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_paged_attention_rope_batch_f32_f16kv(
		(*C.float)(Q),
		(**C.ushort)(kBlocksFlatDev),
		(**C.ushort)(vBlocksFlatDev),
		(*C.int)(blockOffsetsDev),
		(*C.int)(kvLensDev),
		(*C.int)(queryPosDev),
		(*C.float)(out),
		C.int(numTokens),
		C.int(numHeads), C.int(numKVHeads), C.int(headDim),
		C.int(blockSize),
		C.float(scale),
		C.int(maxKvLen),
		C.float(theta),
	)
	if ret != 0 {
		return errors.New("cuda paged attention batch rope f16kv failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// Available returns whether CUDA is available.
func Available() bool {
	return true
}

// MemoryInfo returns (total, free) bytes for the current CUDA device.
func MemoryInfo() (total uint64, free uint64, err error) {
	var cFree, cTotal C.size_t
	ret := C.cuda_mem_info(&cFree, &cTotal)
	if ret != 0 {
		return 0, 0, errors.New("cuda_mem_info failed")
	}
	return uint64(cTotal), uint64(cFree), nil
}

// MemoryInfoDevice returns (total, free) bytes for the given CUDA device.
func MemoryInfoDevice(gpu int) (total uint64, free uint64, err error) {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return 0, 0, errors.New("failed to set cuda device")
	}
	var cFree, cTotal C.size_t
	ret := C.cuda_mem_info(&cFree, &cTotal)
	if ret != 0 {
		return 0, 0, errors.New("cuda_mem_info failed")
	}
	return uint64(cTotal), uint64(cFree), nil
}

// DeviceCount returns the number of visible CUDA devices.
func DeviceCount() (int, error) {
	var cCount C.int
	ret := C.cuda_device_count(&cCount)
	if ret != 0 {
		return 0, errors.New("cuda_device_count failed")
	}
	if cCount < 0 {
		return 0, errors.New("cuda_device_count returned negative")
	}
	return int(cCount), nil
}

// Synchronize waits for all queued work on the given GPU.
// Use when explicit host/device coordination is required.
func Synchronize(gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	if ret := C.cuda_synchronize(); ret != 0 {
		return errors.New("cuda synchronize failed")
	}
	return nil
}
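
// Illustrative sketch (not in the original file): enumerating devices with
// DeviceCount, draining queued work with Synchronize, and reporting memory
// with MemoryInfoDevice. Output formatting is arbitrary.
func exampleReportDevices() error {
	n, err := DeviceCount()
	if err != nil {
		return err
	}
	for gpu := 0; gpu < n; gpu++ {
		if err := Synchronize(gpu); err != nil {
			return err
		}
		total, free, err := MemoryInfoDevice(gpu)
		if err != nil {
			return err
		}
		fmt.Printf("cuda:%d total=%d MiB free=%d MiB\n", gpu, total>>20, free>>20)
	}
	return nil
}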

// ============================================================
// Neural Network Operations
// ============================================================

// RMSNorm applies RMS normalization in-place on GPU.
// x: [seqLen, dim] device pointer, w: [dim] device pointer.
func RMSNorm(x, w unsafe.Pointer, seqLen, dim int, eps float32, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_rmsnorm_f32((*C.float)(x), (*C.float)(w), C.int(seqLen), C.int(dim), C.float(eps))
	if ret != 0 {
		return errors.New("cuda rmsnorm failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// RoPE applies rotary positional embeddings in-place.
// x: [seqLen, numHeads * headDim] device pointer.
// positions: [seqLen] device pointer (int32).
func RoPE(x, positions unsafe.Pointer, seqLen, numHeads, headDim int, theta float32, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_rope_f32((*C.float)(x), (*C.int)(positions), C.int(seqLen), C.int(numHeads), C.int(headDim), C.float(theta))
	if ret != 0 {
		return errors.New("cuda rope failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// RoPESingle runs RoPE for a single token at a specific position.
func RoPESingle(x unsafe.Pointer, pos, numHeads, headDim int, theta float32, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_rope_f32_single((*C.float)(x), C.int(pos), C.int(numHeads), C.int(headDim), C.float(theta))
	if ret != 0 {
		return errors.New("cuda rope single failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// Softmax applies softmax along the last dimension in-place.
func Softmax(x unsafe.Pointer, rows, cols int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_softmax_f32((*C.float)(x), C.int(rows), C.int(cols))
	if ret != 0 {
		return errors.New("cuda softmax failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// SiLU applies the SiLU activation in-place: x = x * sigmoid(x).
func SiLU(x unsafe.Pointer, n int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_silu_f32((*C.float)(x), C.size_t(n))
	if ret != 0 {
		return errors.New("cuda silu failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// MulInplace performs element-wise a = a * b.
func MulInplace(a, b unsafe.Pointer, n int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_mul_inplace_f32((*C.float)(a), (*C.float)(b), C.size_t(n))
	if ret != 0 {
		return errors.New("cuda mul inplace failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// Copy copies GPU memory: dst = src.
func Copy(dst, src unsafe.Pointer, n int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_copy_f32((*C.float)(dst), (*C.float)(src), C.size_t(n))
	if ret != 0 {
		return errors.New("cuda copy failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func KDACausalShortConv1D(x, state, w unsafe.Pointer, tokens, projSize, kernel int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_kda_causal_short_conv1d_f32(
		(*C.float)(x),
		(*C.float)(state),
		(*C.float)(w),
		C.int(tokens),
		C.int(projSize),
		C.int(kernel),
	)
	if ret != 0 {
		return errors.New("cuda kda causal short conv1d failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func L2NormHeads(q, k unsafe.Pointer, tokens, numHeads, headDim int, eps float32, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_l2norm_heads_f32((*C.float)(q), (*C.float)(k), C.int(tokens), C.int(numHeads), C.int(headDim), C.float(eps))
	if ret != 0 {
		return errors.New("cuda l2norm heads failed")
	}
	return nil
}

func KDAGate(g, aLog, dtBias, out unsafe.Pointer, tokens, numHeads, headDim int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_kda_gate_f32((*C.float)(g), (*C.float)(aLog), (*C.float)(dtBias), (*C.float)(out), C.int(tokens), C.int(numHeads), C.int(headDim))
	if ret != 0 {
		return errors.New("cuda kda gate failed")
	}
	return nil
}

func KDARecurrent(q, k, v, g, beta, state unsafe.Pointer, tokens, numHeads, headDim int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_kda_recurrent_f32((*C.float)(q), (*C.float)(k), (*C.float)(v), (*C.float)(g), (*C.float)(beta), (*C.float)(state), C.int(tokens), C.int(numHeads), C.int(headDim))
	if ret != 0 {
		return errors.New("cuda kda recurrent failed")
	}
	return nil
}

func RMSNormGated(out, g, weight unsafe.Pointer, n, headDim int, eps float32, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_rmsnorm_gated_f32((*C.float)(out), (*C.float)(g), (*C.float)(weight), C.int(n), C.int(headDim), C.float(eps))
	if ret != 0 {
		return errors.New("cuda rmsnorm gated failed")
	}
	return nil
}

func Sigmoid(x unsafe.Pointer, n int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_sigmoid_f32((*C.float)(x), C.int(n))
	if ret != 0 {
		return errors.New("cuda sigmoid failed")
	}
	return nil
}

func SoftmaxRows(x unsafe.Pointer, rows, cols int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_softmax_rows_f32((*C.float)(x), C.int(rows), C.int(cols))
	if ret != 0 {
		return errors.New("cuda softmax rows failed")
	}
	return nil
}

func TopKPerRow(scores unsafe.Pointer, indices unsafe.Pointer, values unsafe.Pointer, rows, cols, k int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_topk_per_row_f32((*C.float)(scores), (*C.int)(indices), (*C.float)(values), C.int(rows), C.int(cols), C.int(k))
	if ret != 0 {
		return errors.New("cuda topk per row failed")
	}
	return nil
}

// Attention computes full causal attention on GPU.
// Q: [seqLen, numHeads * headDim]
// K, V: [kvLen, numKVHeads * headDim]
// out: [seqLen, numHeads * headDim]
func Attention(Q, K, V, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim int, scale float32, startPos int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_attention_f32(
		(*C.float)(Q), (*C.float)(K), (*C.float)(V), (*C.float)(out),
		C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim),
		C.float(scale), C.int(startPos),
	)
	if ret != 0 {
		return errors.New("cuda attention failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}
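
// Illustrative sketch (not in the original file): wiring Tensor device buffers
// into the raw Attention entry point via Data(), which returns the underlying
// device pointer. The caller supplies scale (typically 1/sqrt(headDim)).
func exampleAttention(q, k, v, out *Tensor, seqLen, kvLen, numHeads, numKVHeads, headDim int, scale float32, startPos int) error {
	return Attention(
		q.Data().(unsafe.Pointer), k.Data().(unsafe.Pointer),
		v.Data().(unsafe.Pointer), out.Data().(unsafe.Pointer),
		seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos, q.GPU(),
	)
}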

func PagedAttention(Q, kBlocksDev, vBlocksDev, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize int, scale float32, startPos int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_paged_attention_f32(
		(*C.float)(Q),
		(**C.float)(kBlocksDev),
		(**C.float)(vBlocksDev),
		(*C.float)(out),
		C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim),
		C.int(blockSize),
		C.float(scale), C.int(startPos),
	)
	if ret != 0 {
		return errors.New("cuda paged attention failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// AttentionTimed runs attention and returns kernel time in milliseconds.
// Intended for profiling/debugging only (it synchronizes internally).
func AttentionTimed(Q, K, V, out unsafe.Pointer, seqLen, kvLen, numHeads, numKVHeads, headDim int, scale float32, startPos int, gpu int) (float32, error) {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return 0, errors.New("failed to set cuda device")
	}
	var ms C.float
	ret := C.cuda_attention_f32_timed(
		(*C.float)(Q), (*C.float)(K), (*C.float)(V), (*C.float)(out),
		C.int(seqLen), C.int(kvLen), C.int(numHeads), C.int(numKVHeads), C.int(headDim),
		C.float(scale), C.int(startPos), &ms,
	)
	if ret != 0 {
		return 0, errors.New("cuda attention timed failed")
	}
	return float32(ms), nil
}

// AddInplace performs element-wise a = a + b.
func AddInplace(a, b unsafe.Pointer, n int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_add_f32((*C.float)(a), (*C.float)(b), C.size_t(n))
	if ret != 0 {
		return errors.New("cuda add failed")
	}
	return nil
}

// ============================================================
// Dequantization Operations
// ============================================================

// DequantQ8K dequantizes Q8_K blocks on GPU.
// blocks: device pointer to Q8_K data.
// out: device pointer to output float32 (numBlocks * 256 elements).
func DequantQ8K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_dequant_q8k(blocks, (*C.float)(out), C.int(numBlocks))
	if ret != 0 {
		return errors.New("cuda dequant q8k failed")
	}
	return nil
}

func DequantQ4K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_dequant_q4k(blocks, (*C.float)(out), C.int(numBlocks))
	if ret != 0 {
		return errors.New("cuda dequant q4k failed")
	}
	return nil
}

func DequantQ5K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_dequant_q5k(blocks, (*C.float)(out), C.int(numBlocks))
	if ret != 0 {
		return errors.New("cuda dequant q5k failed")
	}
	return nil
}

// DequantQ6K dequantizes Q6_K blocks on GPU.
func DequantQ6K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_dequant_q6k(blocks, (*C.float)(out), C.int(numBlocks))
	if ret != 0 {
		return errors.New("cuda dequant q6k failed")
	}
	return nil
}

func DequantQ3K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_dequant_q3k(blocks, (*C.float)(out), C.int(numBlocks))
	if ret != 0 {
		return errors.New("cuda dequant q3k failed")
	}
	return nil
}

func DequantQ2K(blocks unsafe.Pointer, out unsafe.Pointer, numBlocks int, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_dequant_q2k(blocks, (*C.float)(out), C.int(numBlocks))
	if ret != 0 {
		return errors.New("cuda dequant q2k failed")
	}
	return nil
}

// MatMulQ8K performs C = A @ dequant(B) where B is Q8_K quantized.
// The output pointer is named Cptr so it does not shadow the cgo "C" package.
func MatMulQ8K(A unsafe.Pointer, B unsafe.Pointer, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f32_q8k((*C.float)(A), B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
	if ret != 0 {
		return errors.New("cuda matmul q8k failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func MatMulQ5K(A unsafe.Pointer, B unsafe.Pointer, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f32_q5k((*C.float)(A), B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
	if ret != 0 {
		return errors.New("cuda matmul q5k failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func MatMulQ4K(A unsafe.Pointer, B unsafe.Pointer, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f32_q4k((*C.float)(A), B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
	if ret != 0 {
		return errors.New("cuda matmul q4k failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func MatMulQ2K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	if k%256 != 0 {
		return fmt.Errorf("MatMulQ2K: K must be multiple of 256, got %d", k)
	}
	ret := C.cuda_matmul_f32_q2k((*C.float)(aPtr), bPtr, (*C.float)(cPtr), C.int(m), C.int(k), C.int(n))
	if ret != 0 {
		return errors.New("cuda matmul q2k failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func MatMulQ3K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	if k%256 != 0 {
		return fmt.Errorf("MatMulQ3K: K must be multiple of 256, got %d", k)
	}
	ret := C.cuda_matmul_f32_q3k((*C.float)(aPtr), bPtr, (*C.float)(cPtr), C.int(m), C.int(k), C.int(n))
	if ret != 0 {
		return errors.New("cuda matmul q3k failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func MatMulQ6K(aPtr, bPtr, cPtr unsafe.Pointer, m, k, n int, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	if k%256 != 0 {
		return fmt.Errorf("MatMulQ6K: K must be multiple of 256, got %d", k)
	}
	ret := C.cuda_matmul_f32_q6k((*C.float)(aPtr), bPtr, (*C.float)(cPtr), C.int(m), C.int(k), C.int(n))
	if ret != 0 {
		return errors.New("cuda matmul q6k failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

func MatMulF32(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f32_nt(
		(*C.float)(A),
		(*C.float)(B),
		(*C.float)(Cptr),
		C.int(M), C.int(K), C.int(N),
	)
	if ret != 0 {
		return errors.New("cuda matmul f32 failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// MatMulF16 performs C = A @ B^T where A and B are float16 (stored as uint16),
// and C is float32 output.
func MatMulF16(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f16_nt(
		(*C.ushort)(A),
		(*C.ushort)(B),
		(*C.float)(Cptr),
		C.int(M), C.int(K), C.int(N),
	)
	if ret != 0 {
		return errors.New("cuda matmul f16 failed")
	}
	if err := syncIfProfiling(gpu); err != nil {
		return err
	}
	return nil
}

// FP16 input MatMul variants - 2x memory bandwidth for activations.
// A is FP16, B is quantized, C is FP32 output.
func MatMulF16Q8K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f16_q8k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
	if ret != 0 {
		return errors.New("cuda matmul f16 q8k failed")
	}
	return nil
}

func MatMulF16Q4K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f16_q4k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
	if ret != 0 {
		return errors.New("cuda matmul f16 q4k failed")
	}
	return nil
}

func MatMulF16Q5K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f16_q5k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
	if ret != 0 {
		return errors.New("cuda matmul f16 q5k failed")
	}
	return nil
}

func MatMulF16Q2K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f16_q2k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
	if ret != 0 {
		return errors.New("cuda matmul f16 q2k failed")
	}
	return nil
}

func MatMulF16Q3K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f16_q3k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
	if ret != 0 {
		return errors.New("cuda matmul f16 q3k failed")
	}
	return nil
}

func MatMulF16Q6K(A, B, Cptr unsafe.Pointer, M, K, N, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_matmul_f16_q6k(A, B, (*C.float)(Cptr), C.int(M), C.int(K), C.int(N))
	if ret != 0 {
		return errors.New("cuda matmul f16 q6k failed")
	}
	return nil
}

// UploadQ8K uploads Q8_K blocks from host to GPU.
func UploadQ8K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, errors.New("failed to set cuda device")
	}
	size := len(hostData)
	ptr := C.cuda_malloc(C.size_t(size))
	if ptr == nil {
		return nil, errors.New("cuda malloc failed for Q8K")
	}
	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
	if ret != 0 {
		C.cuda_free(ptr)
		return nil, errors.New("cuda memcpy h2d failed for Q8K")
	}
	return ptr, nil
}

func AllocAndCopyPtrTable(ptrs []uintptr, gpu int) (unsafe.Pointer, error) {
	if len(ptrs) == 0 {
		return nil, errors.New("empty ptr table")
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, errors.New("failed to set cuda device")
	}
	size := len(ptrs) * int(unsafe.Sizeof(uintptr(0)))
	ptr := C.cuda_malloc(C.size_t(size))
	if ptr == nil {
		return nil, errors.New("cuda malloc failed for ptr table")
	}
	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&ptrs[0]), C.size_t(size))
	if ret != 0 {
		C.cuda_free(ptr)
		return nil, errors.New("cuda memcpy h2d failed for ptr table")
	}
	return ptr, nil
}
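
// Illustrative sketch (not in the original file): building the device-side
// pointer table that the paged-attention kernels consume. Each entry is the
// uintptr of a per-block device allocation; the flat table itself lives on the
// GPU and must be released with FreeDevicePtr by the caller.
func exampleBlockPtrTable(blockPtrs []unsafe.Pointer, gpu int) (unsafe.Pointer, error) {
	ptrs := make([]uintptr, len(blockPtrs))
	for i, p := range blockPtrs {
		ptrs[i] = uintptr(p)
	}
	return AllocAndCopyPtrTable(ptrs, gpu)
}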

func UploadQ5K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, errors.New("failed to set cuda device")
	}
	size := len(hostData)
	ptr := C.cuda_malloc(C.size_t(size))
	if ptr == nil {
		return nil, errors.New("cuda malloc failed for Q5K")
	}
	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
	if ret != 0 {
		C.cuda_free(ptr)
		return nil, errors.New("cuda memcpy h2d failed for Q5K")
	}
	return ptr, nil
}

// UploadQ4K uploads Q4_K blocks from host to GPU.
func UploadQ4K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, errors.New("failed to set cuda device")
	}
	size := len(hostData)
	ptr := C.cuda_malloc(C.size_t(size))
	if ptr == nil {
		return nil, errors.New("cuda malloc failed for Q4K")
	}
	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
	if ret != 0 {
		C.cuda_free(ptr)
		return nil, errors.New("cuda memcpy h2d failed for Q4K")
	}
	return ptr, nil
}

// UploadQ2K uploads Q2_K blocks from host to GPU.
func UploadQ2K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, errors.New("failed to set cuda device")
	}
	size := len(hostData)
	ptr := C.cuda_malloc(C.size_t(size))
	if ptr == nil {
		return nil, errors.New("cuda malloc failed for Q2K")
	}
	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
	if ret != 0 {
		C.cuda_free(ptr)
		return nil, errors.New("cuda memcpy h2d failed for Q2K")
	}
	return ptr, nil
}

// UploadQ3K uploads Q3_K blocks from host to GPU.
func UploadQ3K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, errors.New("failed to set cuda device")
	}
	size := len(hostData)
	ptr := C.cuda_malloc(C.size_t(size))
	if ptr == nil {
		return nil, errors.New("cuda malloc failed for Q3K")
	}
	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
	if ret != 0 {
		C.cuda_free(ptr)
		return nil, errors.New("cuda memcpy h2d failed for Q3K")
	}
	return ptr, nil
}

// UploadQ6K uploads Q6_K blocks from host to GPU.
func UploadQ6K(hostData []byte, numBlocks int, gpu int) (unsafe.Pointer, error) {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, errors.New("failed to set cuda device")
	}
	size := len(hostData)
	ptr := C.cuda_malloc(C.size_t(size))
	if ptr == nil {
		return nil, errors.New("cuda malloc failed for Q6K")
	}
	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&hostData[0]), C.size_t(size))
	if ret != 0 {
		C.cuda_free(ptr)
		return nil, errors.New("cuda memcpy h2d failed for Q6K")
	}
	return ptr, nil
}

// MemcpyH2D copies data from host to device pointer.
// dst: device pointer
// src: host data (unsafe.Pointer to first element)
// size: number of bytes
// gpu: device id (must be active or will be set)
func MemcpyH2D(dst, src unsafe.Pointer, size uintptr, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_memcpy_h2d(dst, src, C.size_t(size))
	if ret != 0 {
		return errors.New("cuda memcpy h2d failed")
	}
	return nil
}

// MemcpyD2H copies data from device pointer to host pointer.
func MemcpyD2H(dst, src unsafe.Pointer, size uintptr, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_memcpy_d2h(dst, src, C.size_t(size))
	if ret != 0 {
		return errors.New("cuda memcpy d2h failed")
	}
	return nil
}

func MemcpyD2D(dst, src unsafe.Pointer, size uintptr, gpu int) error {
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return errors.New("failed to set cuda device")
	}
	ret := C.cuda_memcpy_d2d(dst, src, C.size_t(size))
	if ret != 0 {
		return errors.New("cuda memcpy d2d failed")
	}
	return nil
}

// TopKLogitsF32 computes per-block top-k on GPU (with repetition penalty applied)
// and returns the concatenated candidate list on host (caller does final global top-k).
func TopKLogitsF32(logits unsafe.Pointer, vocab int, repIDs []int32, repPenalty float32, k int, gpu int) ([]int32, []float32, int, error) {
	if k <= 0 {
		return nil, nil, 0, nil
	}
	if k > 64 {
		return nil, nil, 0, fmt.Errorf("TopKLogitsF32: k too large: %d", k)
	}
	// Set the device first so the temporary buffers below are allocated on it.
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, nil, 0, errors.New("failed to set cuda device")
	}
	blocks := (vocab + 2048 - 1) / 2048
	if blocks <= 0 {
		blocks = 1
	}
	count := blocks * k
	var repPtr unsafe.Pointer
	if len(repIDs) > 0 {
		p, err := AllocAndCopyInt32(repIDs, gpu)
		if err != nil {
			return nil, nil, 0, err
		}
		repPtr = p
		defer FreeDevicePtr(repPtr)
	}
	// Device outputs.
	outIDsPtr := C.cuda_malloc(C.size_t(count * 4))
	if outIDsPtr == nil {
		return nil, nil, 0, errors.New("TopKLogitsF32: cuda malloc failed for outIDs")
	}
	defer C.cuda_free(outIDsPtr)
	outScoresPtr := C.cuda_malloc(C.size_t(count * 4))
	if outScoresPtr == nil {
		return nil, nil, 0, errors.New("TopKLogitsF32: cuda malloc failed for outScores")
	}
	defer C.cuda_free(outScoresPtr)
	ret := C.cuda_topk_logits_f32(
		(*C.float)(logits),
		C.int(vocab),
		(*C.int)(repPtr),
		C.int(len(repIDs)),
		C.float(repPenalty),
		C.int(k),
		(*C.int)(outIDsPtr),
		(*C.float)(outScoresPtr),
	)
	if ret != 0 {
		return nil, nil, 0, errors.New("cuda topk logits failed")
	}
	ids := make([]int32, count)
	scores := make([]float32, count)
	if err := MemcpyD2H(unsafe.Pointer(&ids[0]), unsafe.Pointer(outIDsPtr), uintptr(count*4), gpu); err != nil {
		return nil, nil, 0, err
	}
	if err := MemcpyD2H(unsafe.Pointer(&scores[0]), unsafe.Pointer(outScoresPtr), uintptr(count*4), gpu); err != nil {
		return nil, nil, 0, err
	}
	return ids, scores, blocks, nil
}
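
// Illustrative sketch (not in the original file): the host-side "final global
// top-k" that TopKLogitsF32 leaves to the caller. A simple selection loop over
// the blocks*k candidates avoids extra imports; for typical k this is cheap.
func exampleGlobalTopK(ids []int32, scores []float32, k int) ([]int32, []float32) {
	if k > len(ids) {
		k = len(ids)
	}
	outIDs := make([]int32, 0, k)
	outScores := make([]float32, 0, k)
	used := make([]bool, len(ids))
	for n := 0; n < k; n++ {
		best := -1
		for i, s := range scores {
			if used[i] {
				continue
			}
			if best < 0 || s > scores[best] {
				best = i
			}
		}
		if best < 0 {
			break
		}
		used[best] = true
		outIDs = append(outIDs, ids[best])
		outScores = append(outScores, scores[best])
	}
	return outIDs, outScores
}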

// FreeDevicePtr frees a device pointer.
func FreeDevicePtr(ptr unsafe.Pointer) {
	if ptr != nil {
		C.cuda_free(ptr)
	}
}

// Free is an alias for FreeDevicePtr for convenience.
func Free(ptr unsafe.Pointer) {
	FreeDevicePtr(ptr)
}

// AllocAndCopyInt32 allocates GPU memory and copies int32 data to it.
// Returns a raw device pointer (caller must Free it).
func AllocAndCopyInt32(data []int32, gpu int) (unsafe.Pointer, error) {
	if len(data) == 0 {
		return nil, errors.New("empty data")
	}
	if ret := C.cuda_set_device(C.int(gpu)); ret != 0 {
		return nil, errors.New("failed to set cuda device")
	}
	size := len(data) * 4 // 4 bytes per int32
	ptr := C.cuda_malloc(C.size_t(size))
	if ptr == nil {
		return nil, errors.New("cuda malloc failed for int32 data")
	}
	ret := C.cuda_memcpy_h2d(ptr, unsafe.Pointer(&data[0]), C.size_t(size))
	if ret != 0 {
		C.cuda_free(ptr)
		return nil, errors.New("cuda memcpy h2d failed for int32 data")
	}
	return ptr, nil
}