#ifndef MAKARNA_CUDA_H
#define MAKARNA_CUDA_H

#include <stddef.h>

#ifdef __cplusplus
extern "C" {
#endif

// Memory Management
int cuda_set_device(int id);
void* cuda_malloc(size_t size);
void cuda_free(void* ptr);
int cuda_synchronize(void);
int cuda_memcpy_h2d(void* dst, void* src, size_t size);
int cuda_memcpy_d2h(void* dst, void* src, size_t size);
int cuda_memcpy_d2d(void* dst, void* src, size_t size);
int cuda_mem_info(size_t* free_bytes, size_t* total_bytes);
int cuda_device_count(int* count);

// Math Operations (Float32)
// These launch kernels on the default stream.
int cuda_add_f32(float* a, float* b, size_t n);
int cuda_mul_f32(float* a, float* b, size_t n);
int cuda_matmul_f32(float* A, float* B, float* C, int M, int K, int N);

// MatMul where B is row-major [N, K] (no host transpose needed).
int cuda_matmul_f32_nt(float* A, float* B, float* C, int M, int K, int N);

// MatMul where A and B are float16 (IEEE half stored as uint16).
// B is row-major [N, K] and interpreted as column-major [K, N].
int cuda_matmul_f16_nt(const unsigned short* A, const unsigned short* B,
                       float* C, int M, int K, int N);
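// Usage sketch (illustrative, not part of the API): a minimal host-side
// round trip through the allocation, copy, and matmul entry points above.
// The dimensions are hypothetical and error checks are elided (the int
// return values are assumed to follow the 0-on-success convention
// documented for cuda_topk_logits_f32 below).
//
//   enum { M = 4, K = 8, N = 4 };
//   float hA[M * K], hB[K * N], hC[M * N];   // filled by the caller
//   cuda_set_device(0);
//   float* dA = (float*)cuda_malloc(M * K * sizeof(float));
//   float* dB = (float*)cuda_malloc(K * N * sizeof(float));
//   float* dC = (float*)cuda_malloc(M * N * sizeof(float));
//   cuda_memcpy_h2d(dA, hA, M * K * sizeof(float));
//   cuda_memcpy_h2d(dB, hB, K * N * sizeof(float));
//   // C = A @ B (B assumed row-major [K, N]; the _nt variant takes [N, K]).
//   cuda_matmul_f32(dA, dB, dC, M, K, N);
//   cuda_synchronize();   // kernels run on the default stream
//   cuda_memcpy_d2h(hC, dC, M * N * sizeof(float));
//   cuda_free(dA); cuda_free(dB); cuda_free(dC);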
// ============================================================
// Neural Network Operations
// ============================================================

// RMSNorm: x = x * rsqrt(mean(x^2) + eps) * weight
// x: [seqLen, dim], w: [dim] -> modifies x in-place.
int cuda_rmsnorm_f32(float* x, const float* w, int seqLen, int dim, float eps);

// RoPE: Apply rotary positional embeddings in-place.
// x: [seqLen, numHeads * headDim]
// positions: [seqLen] - position indices
int cuda_rope_f32(float* x, const int* positions, int seqLen,
                  int numHeads, int headDim, float theta);
int cuda_rope_f32_single(float* x, int pos, int numHeads, int headDim, float theta);

// Softmax: Apply softmax along the last dimension.
// x: [rows, cols] -> in-place.
int cuda_softmax_f32(float* x, int rows, int cols);

// Top-K selection on logits with optional repetition penalty.
// logits: [vocab]
// rep_ids: [rep_count] token ids to penalize
// out_ids/out_scores: [numBlocks * k]
// Returns 0 on success.
int cuda_topk_logits_f32(
    const float* logits, int vocab,
    const int* rep_ids, int rep_count, float rep_penalty,
    int k, int* out_ids, float* out_scores);

// Causal Attention: full attention computation.
// Q: [seqLen, numHeads * headDim]
// K: [kvLen, numKVHeads * headDim]
// V: [kvLen, numKVHeads * headDim]
// out: [seqLen, numHeads * headDim]
// scale: typically 1/sqrt(headDim)
// startPos: for the causal mask offset (KV cache)
int cuda_attention_f32(
    const float* Q, const float* K, const float* V, float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    float scale, int startPos);

int cuda_paged_attention_f32(
    const float* Q,
    const float* const* KBlocks, const float* const* VBlocks,
    float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int startPos);

int cuda_paged_attention_batch_f32(
    const float* Q,
    const float* const* KBlocksFlat, const float* const* VBlocksFlat,
    const int* blockOffsets, const int* kvLens, const int* queryPos,
    float* out,
    int numTokens, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int maxKvLen);

// Paged attention where KV blocks are float16 (IEEE half stored as uint16).
// Q and out are float32. Accumulation is float32.
int cuda_paged_attention_f32_f16kv(
    const float* Q,
    const unsigned short* const* KBlocks, const unsigned short* const* VBlocks,
    float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int startPos);

int cuda_paged_attention_batch_f32_f16kv(
    const float* Q,
    const unsigned short* const* KBlocksFlat, const unsigned short* const* VBlocksFlat,
    const int* blockOffsets, const int* kvLens, const int* queryPos,
    float* out,
    int numTokens, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int maxKvLen);

// Fused RoPE + paged attention where KV blocks are float16 (IEEE half
// stored as uint16). Expects un-rotated Q and un-rotated K blocks; RoPE
// is applied on-the-fly in the attention kernel.
int cuda_paged_attention_rope_f32_f16kv(
    const float* Q,
    const unsigned short* const* KBlocks, const unsigned short* const* VBlocks,
    float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int startPos, float theta);

int cuda_paged_attention_rope_batch_f32_f16kv(
    const float* Q,
    const unsigned short* const* KBlocksFlat, const unsigned short* const* VBlocksFlat,
    const int* blockOffsets, const int* kvLens, const int* queryPos,
    float* out,
    int numTokens, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int maxKvLen, float theta);

// Cast float32 -> float16 (stored as uint16) on GPU.
int cuda_cast_f32_to_f16(const float* src, unsigned short* dst, int n);

int cuda_attention_f32_timed(
    const float* Q, const float* K, const float* V, float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    float scale, int startPos, float* ms);

// SiLU activation: x = x * sigmoid(x), in-place.
int cuda_silu_f32(float* x, size_t n);

// Element-wise multiply: a = a * b, in-place.
int cuda_mul_inplace_f32(float* a, const float* b, size_t n);

// Copy: dst = src
int cuda_copy_f32(float* dst, const float* src, size_t n);

int cuda_kda_causal_short_conv1d_f32(
    float* x, float* state, const float* w,
    int tokens, int projSize, int kernel);

int cuda_l2norm_heads_f32(
    float* q, float* k, int tokens, int numHeads, int headDim, float eps);

int cuda_kda_gate_f32(
    const float* g, const float* aLog, const float* dtBias,
    float* out, int tokens, int numHeads, int headDim);

int cuda_kda_recurrent_f32(
    const float* q, const float* k, float* v, const float* g,
    // beta is a device pointer: [tokens, numHeads] (row-major).
    const float* beta,
    float* state, int tokens, int numHeads, int headDim);

int cuda_rmsnorm_gated_f32(
    float* out, const float* g, const float* weight,
    int n, int headDim, float eps);

int cuda_sigmoid_f32(float* x, int n);
int cuda_softmax_rows_f32(float* x, int rows, int cols);

int cuda_topk_per_row_f32(
    const float* scores, int* indices, float* values,
    int rows, int cols, int k);
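// Usage sketch (illustrative only, not the library's actual pipeline):
// one single-token decode step wired from the primitives above. All
// pointers are device buffers; the names (x, q, k, attnNormW, logits, ...),
// dimension values, and theta are hypothetical stand-ins, and the QKV
// projections and KV-cache bookkeeping are elided.
//
//   // Normalize the incoming [1, dim] activation row in-place.
//   cuda_rmsnorm_f32(x, attnNormW, /*seqLen=*/1, dim, /*eps=*/1e-6f);
//   // ... project x into q / k / v with the matmul entry points ...
//   // Rotate the single query/key row in-place at position `pos`.
//   cuda_rope_f32_single(q, pos, numHeads, headDim, /*theta=*/10000.0f);
//   cuda_rope_f32_single(k, pos, numKVHeads, headDim, 10000.0f);
//   // ... append the rotated k and v rows to the contiguous KV cache ...
//   // Attend over kvLen cached rows; startPos=pos sets the causal offset.
//   cuda_attention_f32(q, K, V, attnOut, /*seqLen=*/1, kvLen,
//                      numHeads, numKVHeads, headDim,
//                      1.0f / sqrtf((float)headDim),   // sqrtf from <math.h>
//                      /*startPos=*/pos);
//   // ... FFN + output projection produce `logits` ...
//   // Sample candidates: top-k logits with a repetition penalty.
//   cuda_topk_logits_f32(logits, vocab, recentIds, recentCount,
//                        /*rep_penalty=*/1.1f, /*k=*/40, outIds, outScores);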
// ============================================================
// Dequantization Kernels
// These convert quantized blocks to float32 on the GPU.
// ============================================================

// Block size for K-quantization.
#define QK_K 256

// BlockQ8_K: 292 bytes per block (4 + 256 + 32)
// - D (4 bytes): float32 scale
// - QS (256 bytes): 256 int8 quants
// - BSums (32 bytes): unused in dequant
typedef struct {
    float d;
    signed char qs[256];
    short bsums[16];
} BlockQ8_K;

// BlockQ4_K: 144 bytes per block
// - D (2 bytes): float16 super-scale
// - DMin (2 bytes): float16 super-min
// - Scales (12 bytes): packed 6-bit scales/mins
// - QS (128 bytes): 256 4-bit quants
typedef struct {
    unsigned short d;
    unsigned short dmin;
    unsigned char scales[12];
    unsigned char qs[128];
} BlockQ4_K;

// BlockQ5_K: 176 bytes per block
// - D (2 bytes): float16 super-scale
// - DMin (2 bytes): float16 super-min
// - Scales (12 bytes): packed 6-bit scales/mins
// - QH (32 bytes): high bits
// - QS (128 bytes): low 4 bits
typedef struct {
    unsigned short d;
    unsigned short dmin;
    unsigned char scales[12];
    unsigned char qh[32];
    unsigned char qs[128];
} BlockQ5_K;

// BlockQ6_K: 210 bytes per block
// - QL (128 bytes): lower 4 bits
// - QH (64 bytes): upper 2 bits
// - Scales (16 bytes): 8-bit scales
// - D (2 bytes): float16 super-scale
typedef struct {
    unsigned char ql[128];
    unsigned char qh[64];
    signed char scales[16];
    unsigned short d;
} BlockQ6_K;

// BlockQ3_K: 110 bytes per block
// - HMask (32 bytes): high bits
// - QS (64 bytes): low 2 bits
// - Scales (12 bytes): packed 6-bit scales
// - D (2 bytes): float16 super-scale
typedef struct {
    unsigned char hmask[32];
    unsigned char qs[64];
    unsigned char scales[12];
    unsigned short d;
} BlockQ3_K;

// BlockQ2_K: 84 bytes per block
// - Scales (16 bytes): packed 4-bit scales/mins
// - QS (64 bytes): 256 2-bit quants
// - D (2 bytes): float16 super-scale
// - DMin (2 bytes): float16 super-min
typedef struct {
    unsigned char scales[16];
    unsigned char qs[64];
    unsigned short d;
    unsigned short dmin;
} BlockQ2_K;

// Dequantize a row of Q8_K blocks: numBlocks * 256 values -> out.
int cuda_dequant_q8k(const void* blocks, float* out, int numBlocks);
// Dequantize a row of Q4_K blocks.
int cuda_dequant_q4k(const void* blocks, float* out, int numBlocks);
// Dequantize a row of Q5_K blocks.
int cuda_dequant_q5k(const void* blocks, float* out, int numBlocks);
// Dequantize a row of Q6_K blocks.
int cuda_dequant_q6k(const void* blocks, float* out, int numBlocks);
// Dequantize a row of Q3_K blocks.
int cuda_dequant_q3k(const void* blocks, float* out, int numBlocks);
// Dequantize a row of Q2_K blocks.
int cuda_dequant_q2k(const void* blocks, float* out, int numBlocks);
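// Layout check sketch (illustrative): the dequant kernels read these
// structs byte-for-byte from quantized tensor data, so sizeof() must
// match the documented per-block byte counts (the structs above pack
// without padding on common ABIs). A translation unit could assert
// this at compile time; the cuda_print_struct_sizes() debug helper
// below presumably reports the device-side sizes at runtime.
//
//   _Static_assert(sizeof(BlockQ8_K) == 292, "Q8_K block is 292 bytes");
//   _Static_assert(sizeof(BlockQ4_K) == 144, "Q4_K block is 144 bytes");
//   _Static_assert(sizeof(BlockQ5_K) == 176, "Q5_K block is 176 bytes");
//   _Static_assert(sizeof(BlockQ6_K) == 210, "Q6_K block is 210 bytes");
//   _Static_assert(sizeof(BlockQ3_K) == 110, "Q3_K block is 110 bytes");
//   _Static_assert(sizeof(BlockQ2_K) == 84,  "Q2_K block is 84 bytes");
//
// Each block decodes to QK_K (256) float values, so a row of K elements
// occupies K / QK_K blocks, e.g. (K / 256) * sizeof(BlockQ4_K) bytes
// per Q4_K row.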
// Fused Dequant + MatMul (for maximum performance)
// A: [M, K] float32 input
// B: quantized weight blocks [N rows, K/256 blocks per row]
// C: [M, N] float32 output
// These dequantize B on-the-fly and compute C = A @ B.T.
int cuda_matmul_f32_q8k(float* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f32_q5k(float* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f32_q6k(float* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f32_q4k(float* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f32_q3k(float* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f32_q2k(float* A, const void* B, float* C, int M, int K, int N);

int cuda_matmul_f32_q8k_timed(float* A, const void* B, float* C,
                              int M, int K, int N, float* ms);
int cuda_matmul_f32_q4k_timed(float* A, const void* B, float* C,
                              int M, int K, int N, float* ms);

// FP16 Input Variants - 2x memory bandwidth for activations.
// A: [M, K] float16 input, B: quantized, C: [M, N] float32 output.
int cuda_matmul_f16_q8k(const void* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f16_q4k(const void* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f16_q5k(const void* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f16_q2k(const void* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f16_q3k(const void* A, const void* B, float* C, int M, int K, int N);
int cuda_matmul_f16_q6k(const void* A, const void* B, float* C, int M, int K, int N);

// Debug helper
int cuda_print_struct_sizes(void);

#ifdef __cplusplus
}
#endif

#endif // MAKARNA_CUDA_H