// kernels.h (~11 KB) — C API for the Makarna CUDA kernel library.
  1. #ifndef MAKARNA_CUDA_H
  2. #define MAKARNA_CUDA_H
  3. #include <stddef.h>
  4. #ifdef __cplusplus
  5. extern "C" {
  6. #endif
  7. // Memory Management
  8. int cuda_set_device(int id);
  9. void* cuda_malloc(size_t size);
  10. void cuda_free(void* ptr);
  11. int cuda_synchronize();
  12. int cuda_memcpy_h2d(void* dst, void* src, size_t size);
  13. int cuda_memcpy_d2h(void* dst, void* src, size_t size);
  14. int cuda_memcpy_d2d(void* dst, void* src, size_t size);
  15. int cuda_mem_info(size_t* free_bytes, size_t* total_bytes);
  16. int cuda_device_count(int* count);
  17. // Math Operations (Float32)
  18. // Launches kernels on the default stream
  19. int cuda_add_f32(float* a, float* b, size_t n);
  20. int cuda_mul_f32(float* a, float* b, size_t n);
  21. int cuda_matmul_f32(float* A, float* B, float* C, int M, int K, int N);
  22. // MatMul where B is row-major [N, K] (no host transpose needed).
  23. int cuda_matmul_f32_nt(float* A, float* B, float* C, int M, int K, int N);
  24. // MatMul where A and B are float16 (IEEE half stored as uint16).
  25. // B is row-major [N, K] and interpreted as column-major [K, N].
  26. int cuda_matmul_f16_nt(const unsigned short* A, const unsigned short* B, float* C, int M, int K, int N);
  27. // ============================================================
  28. // Neural Network Operations
  29. // ============================================================
  30. // RMSNorm: x = x * rsqrt(mean(x^2) + eps) * weight
  31. // x: [seqLen, dim], w: [dim] -> modifies x in-place
  32. int cuda_rmsnorm_f32(float* x, const float* w, int seqLen, int dim, float eps);
  33. // RoPE: Apply rotary positional embeddings in-place
  34. // x: [seqLen, numHeads * headDim]
  35. // positions: [seqLen] - position indices
  36. int cuda_rope_f32(float* x, const int* positions, int seqLen, int numHeads, int headDim, float theta);
  37. int cuda_rope_f32_single(float* x, int pos, int numHeads, int headDim, float theta);
  38. // Softmax: Apply softmax along last dimension
  39. // x: [rows, cols] -> in-place
  40. int cuda_softmax_f32(float* x, int rows, int cols);
  41. // Top-K selection on logits with optional repetition penalty.
  42. // logits: [vocab]
  43. // rep_ids: [rep_count] token ids to penalize
  44. // out_ids/out_scores: [numBlocks * k]
  45. // Returns 0 on success.
  46. int cuda_topk_logits_f32(
  47. const float* logits, int vocab,
  48. const int* rep_ids, int rep_count, float rep_penalty,
  49. int k,
  50. int* out_ids, float* out_scores);
  51. // Causal Attention: Full attention computation
  52. // Q: [seqLen, numHeads * headDim]
  53. // K: [kvLen, numKVHeads * headDim]
  54. // V: [kvLen, numKVHeads * headDim]
  55. // out: [seqLen, numHeads * headDim]
  56. // scale: typically 1/sqrt(headDim)
  57. // startPos: for causal mask offset (KV cache)
  58. int cuda_attention_f32(
  59. const float* Q, const float* K, const float* V, float* out,
  60. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  61. float scale, int startPos);
  62. int cuda_paged_attention_f32(
  63. const float* Q,
  64. const float* const* KBlocks,
  65. const float* const* VBlocks,
  66. float* out,
  67. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  68. int blockSize,
  69. float scale, int startPos);
  70. int cuda_paged_attention_batch_f32(
  71. const float* Q,
  72. const float* const* KBlocksFlat,
  73. const float* const* VBlocksFlat,
  74. const int* blockOffsets,
  75. const int* kvLens,
  76. const int* queryPos,
  77. float* out,
  78. int numTokens,
  79. int numHeads, int numKVHeads, int headDim,
  80. int blockSize,
  81. float scale,
  82. int maxKvLen);
  83. // Paged attention where KV blocks are float16 (IEEE half stored as uint16).
  84. // Q and out are float32. Accumulation is float32.
  85. int cuda_paged_attention_f32_f16kv(
  86. const float* Q,
  87. const unsigned short* const* KBlocks,
  88. const unsigned short* const* VBlocks,
  89. float* out,
  90. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  91. int blockSize,
  92. float scale, int startPos);
  93. int cuda_paged_attention_batch_f32_f16kv(
  94. const float* Q,
  95. const unsigned short* const* KBlocksFlat,
  96. const unsigned short* const* VBlocksFlat,
  97. const int* blockOffsets,
  98. const int* kvLens,
  99. const int* queryPos,
  100. float* out,
  101. int numTokens,
  102. int numHeads, int numKVHeads, int headDim,
  103. int blockSize,
  104. float scale,
  105. int maxKvLen);
  106. // Fused RoPE + paged attention where KV blocks are float16 (IEEE half stored as uint16).
  107. // Expects un-rotated Q and un-rotated K blocks; RoPE is applied on-the-fly in the attention kernel.
  108. int cuda_paged_attention_rope_f32_f16kv(
  109. const float* Q,
  110. const unsigned short* const* KBlocks,
  111. const unsigned short* const* VBlocks,
  112. float* out,
  113. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  114. int blockSize,
  115. float scale, int startPos,
  116. float theta);
  117. int cuda_paged_attention_rope_batch_f32_f16kv(
  118. const float* Q,
  119. const unsigned short* const* KBlocksFlat,
  120. const unsigned short* const* VBlocksFlat,
  121. const int* blockOffsets,
  122. const int* kvLens,
  123. const int* queryPos,
  124. float* out,
  125. int numTokens,
  126. int numHeads, int numKVHeads, int headDim,
  127. int blockSize,
  128. float scale,
  129. int maxKvLen,
  130. float theta);
  131. // Cast float32 -> float16 (stored as uint16) on GPU.
  132. int cuda_cast_f32_to_f16(const float* src, unsigned short* dst, int n);
  133. int cuda_attention_f32_timed(
  134. const float* Q, const float* K, const float* V, float* out,
  135. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  136. float scale, int startPos, float* ms);
  137. // SiLU activation: x = x * sigmoid(x), in-place
  138. int cuda_silu_f32(float* x, size_t n);
  139. // Element-wise multiply: a = a * b, in-place
  140. int cuda_mul_inplace_f32(float* a, const float* b, size_t n);
  141. // Copy: dst = src
  142. int cuda_copy_f32(float* dst, const float* src, size_t n);
  143. int cuda_kda_causal_short_conv1d_f32(
  144. float* x,
  145. float* state,
  146. const float* w,
  147. int tokens,
  148. int projSize,
  149. int kernel);
  150. int cuda_l2norm_heads_f32(
  151. float* q,
  152. float* k,
  153. int tokens,
  154. int numHeads,
  155. int headDim,
  156. float eps);
  157. int cuda_kda_gate_f32(
  158. const float* g,
  159. const float* aLog,
  160. const float* dtBias,
  161. float* out,
  162. int tokens,
  163. int numHeads,
  164. int headDim);
  165. int cuda_kda_recurrent_f32(
  166. const float* q,
  167. const float* k,
  168. float* v,
  169. const float* g,
  170. // beta is a device pointer: [tokens, numHeads] (row-major).
  171. const float* beta,
  172. float* state,
  173. int tokens,
  174. int numHeads,
  175. int headDim);
  176. int cuda_rmsnorm_gated_f32(
  177. float* out,
  178. const float* g,
  179. const float* weight,
  180. int n,
  181. int headDim,
  182. float eps);
  183. int cuda_sigmoid_f32(float* x, int n);
  184. int cuda_softmax_rows_f32(float* x, int rows, int cols);
  185. int cuda_topk_per_row_f32(
  186. const float* scores,
  187. int* indices,
  188. float* values,
  189. int rows,
  190. int cols,
  191. int k);
  192. // ============================================================
  193. // Dequantization Kernels
  194. // These convert quantized blocks to float32 on GPU
  195. // ============================================================
  196. // Block sizes for K-quantization
  197. #define QK_K 256
  198. // BlockQ8_K: 292 bytes per block (4 + 256 + 32)
  199. // - D (4 bytes): float32 scale
  200. // - QS (256 bytes): 256 int8 quants
  201. // - BSums (32 bytes): unused in dequant
  202. typedef struct {
  203. float d;
  204. signed char qs[256];
  205. short bsums[16];
  206. } BlockQ8_K;
  207. // BlockQ4_K: 144 bytes per block
  208. // - D (2 bytes): float16 super-scale
  209. // - DMin (2 bytes): float16 super-min
  210. // - Scales (12 bytes): packed 6-bit scales/mins
  211. // - QS (128 bytes): 256 4-bit quants
  212. typedef struct {
  213. unsigned short d;
  214. unsigned short dmin;
  215. unsigned char scales[12];
  216. unsigned char qs[128];
  217. } BlockQ4_K;
  218. typedef struct {
  219. unsigned short d;
  220. unsigned short dmin;
  221. unsigned char scales[12];
  222. unsigned char qh[32];
  223. unsigned char qs[128];
  224. } BlockQ5_K;
  225. // BlockQ6_K: 210 bytes per block
  226. // - QL (128 bytes): lower 4 bits
  227. // - QH (64 bytes): upper 2 bits
  228. // - Scales (16 bytes): 8-bit scales
  229. // - D (2 bytes): float16 super-scale
  230. typedef struct {
  231. unsigned char ql[128];
  232. unsigned char qh[64];
  233. signed char scales[16];
  234. unsigned short d;
  235. } BlockQ6_K;
  236. // BlockQ3_K: 110 bytes per block
  237. // - HMask (32 bytes): high bits
  238. // - QS (64 bytes): low 2 bits
  239. // - Scales (12 bytes): packed 6-bit scales
  240. // - D (2 bytes): float16 super-scale
  241. typedef struct {
  242. unsigned char hmask[32];
  243. unsigned char qs[64];
  244. unsigned char scales[12];
  245. unsigned short d;
  246. } BlockQ3_K;
  247. // BlockQ2_K: 84 bytes per block
  248. // - Scales (16 bytes): packed 4-bit scales/mins
  249. // - QS (64 bytes): 256 2-bit quants
  250. // - D (2 bytes): float16 super-scale
  251. // - DMin (2 bytes): float16 super-min
  252. typedef struct {
  253. unsigned char scales[16];
  254. unsigned char qs[64];
  255. unsigned short d;
  256. unsigned short dmin;
  257. } BlockQ2_K;
  258. // Dequantize a row of Q8_K blocks: numBlocks * 256 values -> out
  259. int cuda_dequant_q8k(const void* blocks, float* out, int numBlocks);
  260. // Dequantize a row of Q4_K blocks
  261. int cuda_dequant_q4k(const void* blocks, float* out, int numBlocks);
  262. int cuda_dequant_q5k(const void* blocks, float* out, int numBlocks);
  263. // Dequantize a row of Q6_K blocks
  264. int cuda_dequant_q6k(const void* blocks, float* out, int numBlocks);
  265. // Dequantize a row of Q3_K blocks
  266. int cuda_dequant_q3k(const void* blocks, float* out, int numBlocks);
  267. // Dequantize a row of Q2_K blocks
  268. int cuda_dequant_q2k(const void* blocks, float* out, int numBlocks);
  269. // Fused Dequant + MatMul (for maximum performance)
  270. // A: [M, K] float32 input
  271. // B: quantized weight blocks [N rows, K/256 blocks per row]
  272. // C: [M, N] float32 output
  273. // This dequantizes B on-the-fly and computes C = A @ B.T
  274. int cuda_matmul_f32_q8k(float* A, const void* B, float* C, int M, int K, int N);
  275. int cuda_matmul_f32_q5k(float* A, const void* B, float* C, int M, int K, int N);
  276. int cuda_matmul_f32_q6k(float* A, const void* B, float* C, int M, int K, int N);
  277. int cuda_matmul_f32_q4k(float* A, const void* B, float* C, int M, int K, int N);
  278. int cuda_matmul_f32_q3k(float* A, const void* B, float* C, int M, int K, int N);
  279. int cuda_matmul_f32_q2k(float* A, const void* B, float* C, int M, int K, int N);
  280. int cuda_matmul_f32_q8k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms);
  281. int cuda_matmul_f32_q4k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms);
  282. // FP16 Input Variants - 2x memory bandwidth for activations
  283. // A: [M, K] float16 input, B: quantized, C: [M, N] float32 output
  284. int cuda_matmul_f16_q8k(const void* A, const void* B, float* C, int M, int K, int N);
  285. int cuda_matmul_f16_q4k(const void* A, const void* B, float* C, int M, int K, int N);
  286. int cuda_matmul_f16_q5k(const void* A, const void* B, float* C, int M, int K, int N);
  287. int cuda_matmul_f16_q2k(const void* A, const void* B, float* C, int M, int K, int N);
  288. int cuda_matmul_f16_q3k(const void* A, const void* B, float* C, int M, int K, int N);
  289. int cuda_matmul_f16_q6k(const void* A, const void* B, float* C, int M, int K, int N);
  290. // Debug helper
  291. int cuda_print_struct_sizes();
  292. #ifdef __cplusplus
  293. }
  294. #endif
  295. #endif // MAKARNA_CUDA_H