#include "cuda_common.cuh" // --- Kernels --- __global__ void add_kernel(float* a, const float* b, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { a[idx] += b[idx]; } } int cuda_add_f32(float* a, float* b, size_t n) { int threads = 256; int blocks = (int)((n + threads - 1) / threads); add_kernel<<>>(a, b, (int)n); CHECK_CUDA(cudaGetLastError()); return 0; } __global__ void mul_kernel(float* a, float* b, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { a[idx] *= b[idx]; } } int cuda_mul_f32(float* a, float* b, size_t n) { int threads = 256; int blocks = (n + threads - 1) / threads; mul_kernel<<>>(a, b, n); CHECK_CUDA(cudaGetLastError()); return 0; } // SiLU kernel: x = x * sigmoid(x) __global__ void silu_kernel(float* x, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { float val = x[idx]; x[idx] = val / (1.0f + __expf(-val)); } } int cuda_silu_f32(float* x, size_t n) { int threads = 256; int blocks = (n + threads - 1) / threads; silu_kernel<<>>(x, n); CHECK_CUDA(cudaGetLastError()); return 0; } // Element-wise multiply in-place __global__ void mul_inplace_kernel(float* a, const float* b, int n) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) { a[idx] *= b[idx]; } } int cuda_mul_inplace_f32(float* a, const float* b, size_t n) { int threads = 256; int blocks = (n + threads - 1) / threads; mul_inplace_kernel<<>>(a, b, n); CHECK_CUDA(cudaGetLastError()); return 0; } // Copy kernel int cuda_copy_f32(float* dst, const float* src, size_t n) { CHECK_CUDA(cudaMemcpy(dst, src, n * sizeof(float), cudaMemcpyDeviceToDevice)); return 0; } // ============================================================ // KDA: Causal short conv1d + SiLU // ============================================================ static __device__ __forceinline__ float sigmoid_f32(float x) { return 1.0f / (1.0f + __expf(-x)); } static __device__ __forceinline__ float silu_f32(float x) { return x * sigmoid_f32(x); } // xTok: [projSize] // state: [projSize, convLen] // w: [projSize, kernel] (assumed contiguous) __global__ void kda_causal_short_conv1d_token_kernel( float* xTok, float* state, const float* w, int projSize, int kernel, int convLen ) { int d = blockIdx.x * blockDim.x + threadIdx.x; if (d >= projSize) { return; } const int wBase = d * kernel; const int stBase = d * convLen; // Read input before overwriting xTok. const float xIn = xTok[d]; float acc = 0.0f; for (int j = 0; j < convLen; j++) { acc = fmaf(w[wBase + j], state[stBase + j], acc); } acc = fmaf(w[wBase + convLen], xIn, acc); xTok[d] = silu_f32(acc); // Update causal state: shift left and append xIn. if (convLen > 0) { for (int j = 0; j < convLen - 1; j++) { state[stBase + j] = state[stBase + j + 1]; } state[stBase + convLen - 1] = xIn; } } int cuda_kda_causal_short_conv1d_f32( float* x, float* state, const float* w, int tokens, int projSize, int kernel ) { if (tokens <= 0 || projSize <= 0) { return 0; } if (kernel <= 1) { // Just SiLU. 
int cuda_kda_causal_short_conv1d_f32(
    float* x, float* state, const float* w,
    int tokens, int projSize, int kernel
) {
    if (tokens <= 0 || projSize <= 0) {
        return 0;
    }
    if (kernel <= 1) {
        // No causal state to mix in: just apply SiLU.
        return cuda_silu_f32(x, (size_t)tokens * (size_t)projSize);
    }
    const int convLen = kernel - 1;
    int threads = 256;
    int blocks = (projSize + threads - 1) / threads;
    for (int t = 0; t < tokens; t++) {
        float* xTok = x + (size_t)t * (size_t)projSize;
        kda_causal_short_conv1d_token_kernel<<<blocks, threads>>>(
            xTok, state, w, projSize, kernel, convLen);
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}

// ============================================================
// KDA: L2 Norm Heads
// ============================================================

__global__ void kda_l2norm_head_kernel(float* x, int headDim, float eps) {
    // One block per head segment.
    extern __shared__ float sdata[];
    int tid = threadIdx.x;
    float* head = x + blockIdx.x * headDim;

    // Compute sum of squares.
    float sum = 0.0f;
    for (int i = tid; i < headDim; i += blockDim.x) {
        float v = head[i];
        sum += v * v;
    }
    sdata[tid] = sum;
    __syncthreads();

    // Tree reduction (blockDim.x must be a power of two).
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    float invNorm = rsqrtf(sdata[0] + eps);

    // Normalize.
    for (int i = tid; i < headDim; i += blockDim.x) {
        head[i] *= invNorm;
    }
}

int cuda_l2norm_heads_f32(float* q, float* k, int tokens, int numHeads, int headDim, float eps) {
    if (tokens <= 0 || numHeads <= 0 || headDim <= 0) return 0;
    int totalHeads = tokens * numHeads;
    // Round the thread count up to a power of two (capped at 256) so the
    // tree reduction in the kernel is exact; extra threads contribute 0.
    int threads = 1;
    while (threads < headDim && threads < 256) threads <<= 1;
    size_t sharedMem = threads * sizeof(float);
    kda_l2norm_head_kernel<<<totalHeads, threads, sharedMem>>>(q, headDim, eps);
    CHECK_CUDA(cudaGetLastError());
    kda_l2norm_head_kernel<<<totalHeads, threads, sharedMem>>>(k, headDim, eps);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// ============================================================
// KDA: Gate computation
// g_out = -exp(aLog[h]) * softplus(g + dtBias)
// ============================================================

__device__ __forceinline__ float softplus_f32(float x) {
    return (x > 20.0f) ? x : logf(1.0f + __expf(x));
}

__global__ void kda_gate_kernel(
    const float* g, const float* aLog, const float* dtBias, float* out,
    int numHeads, int headDim
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int projSize = numHeads * headDim;
    if (idx >= projSize) return;
    int h = idx / headDim;
    float mul = -__expf(aLog[h]);
    float x = g[idx];
    if (dtBias != nullptr) {
        x += dtBias[idx];
    }
    out[idx] = mul * softplus_f32(x);
}

int cuda_kda_gate_f32(
    const float* g, const float* aLog, const float* dtBias, float* out,
    int tokens, int numHeads, int headDim
) {
    if (tokens <= 0) return 0;
    int projSize = numHeads * headDim;
    int threads = 256;
    int blocks = (projSize + threads - 1) / threads;
    for (int t = 0; t < tokens; t++) {
        const float* gTok = g + (size_t)t * (size_t)projSize;
        float* outTok = out + (size_t)t * (size_t)projSize;
        kda_gate_kernel<<<blocks, threads>>>(gTok, aLog, dtBias, outTok, numHeads, headDim);
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}
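// ------------------------------------------------------------
// Reference sketch (host-side, not part of the library API): the gate
// formula above for one token, useful for checking kernel output.
// softplus_ref / kda_gate_ref are illustrative names; dtBias is indexed
// per element, mirroring the kernel.
// ------------------------------------------------------------
static inline float softplus_ref(float x) {
    return (x > 20.0f) ? x : logf(1.0f + expf(x));
}

static inline void kda_gate_ref(
    const float* g, const float* aLog, const float* dtBias, float* out,
    int numHeads, int headDim
) {
    for (int h = 0; h < numHeads; h++) {
        const float mul = -expf(aLog[h]);
        for (int d = 0; d < headDim; d++) {
            const int idx = h * headDim + d;
            float x = g[idx];
            if (dtBias != nullptr) x += dtBias[idx];
            out[idx] = mul * softplus_ref(x);
        }
    }
}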
// ============================================================
// KDA: Recurrent (per-token, per-head)
// state[h]: [headDim, headDim]
// ============================================================

__global__ void kda_recurrent_step_kernel(
    const float* qTok, const float* kTok, float* vTok, const float* gTok,
    const float* betaTok, float* state,
    int numHeads, int headDim, float scale
) {
    // One block per head (blockIdx.x); threads work on headDim elements.
    extern __shared__ float shared[];
    float* tmpKV = shared;
    float* tmpVM = shared + headDim;

    int h = blockIdx.x;
    if (h >= numHeads) return;
    int tid = threadIdx.x;
    int stateStride = headDim * headDim;
    int off = h * headDim;

    const float* q = qTok + off;
    const float* k = kTok + off;
    float* v = vTok + off;
    const float* g = gTok + off;
    float beta = betaTok[h];
    float* st = state + h * stateStride;

    // Step 1: Decay state by exp(g)
    for (int kk = tid; kk < headDim; kk += blockDim.x) {
        float dec = __expf(g[kk]);
        for (int vv = 0; vv < headDim; vv++) {
            st[kk * headDim + vv] *= dec;
        }
    }
    __syncthreads();

    // Step 2: tmpKV = k^T @ state (for each v dimension)
    for (int vv = tid; vv < headDim; vv += blockDim.x) {
        float acc = 0.0f;
        for (int kk = 0; kk < headDim; kk++) {
            acc += k[kk] * st[kk * headDim + vv];
        }
        tmpKV[vv] = acc;
    }
    __syncthreads();

    // Step 3: tmpVM = v - tmpKV
    for (int vv = tid; vv < headDim; vv += blockDim.x) {
        tmpVM[vv] = v[vv] - tmpKV[vv];
    }
    __syncthreads();

    // Step 4: state += beta * k @ tmpVM^T
    for (int kk = tid; kk < headDim; kk += blockDim.x) {
        float kj = beta * k[kk];
        for (int vv = 0; vv < headDim; vv++) {
            st[kk * headDim + vv] += kj * tmpVM[vv];
        }
    }
    __syncthreads();

    // Step 5: v = (q * scale)^T @ state
    for (int vv = tid; vv < headDim; vv += blockDim.x) {
        float acc = 0.0f;
        for (int kk = 0; kk < headDim; kk++) {
            acc += (q[kk] * scale) * st[kk * headDim + vv];
        }
        v[vv] = acc;
    }
}

int cuda_kda_recurrent_f32(
    const float* q, const float* k, float* v, const float* g, const float* beta,
    float* state, int tokens, int numHeads, int headDim
) {
    if (tokens <= 0 || numHeads <= 0 || headDim <= 0) return 0;
    int projSize = numHeads * headDim;
    float scale = 1.0f / sqrtf((float)headDim);
    int threads = min(256, headDim);
    size_t sharedMem = 2 * headDim * sizeof(float);
    for (int t = 0; t < tokens; t++) {
        const float* qTok = q + (size_t)t * (size_t)projSize;
        const float* kTok = k + (size_t)t * (size_t)projSize;
        float* vTok = v + (size_t)t * (size_t)projSize;
        const float* gTok = g + (size_t)t * (size_t)projSize;
        const float* betaTok = beta + (size_t)t * (size_t)numHeads;
        kda_recurrent_step_kernel<<<numHeads, threads, sharedMem>>>(
            qTok, kTok, vTok, gTok, betaTok, state, numHeads, headDim, scale
        );
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}
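// ------------------------------------------------------------
// Reference sketch (host-side, not part of the library API): one recurrent
// step for a single head, mirroring Steps 1-5 of the kernel above
// (decay, k^T S readout, delta, rank-1 update, scaled q readout).
// kda_recurrent_head_ref is an illustrative name for test code.
// ------------------------------------------------------------
static inline void kda_recurrent_head_ref(
    const float* q, const float* k, float* v, const float* g, float beta,
    float* st, int headDim, float scale
) {
    // Step 1: decay each state row by exp(g[kk]).
    for (int kk = 0; kk < headDim; kk++) {
        const float dec = expf(g[kk]);
        for (int vv = 0; vv < headDim; vv++) st[kk * headDim + vv] *= dec;
    }
    // Steps 2-4: delta-rule update, column by column of the state.
    for (int vv = 0; vv < headDim; vv++) {
        float kv = 0.0f;
        for (int kk = 0; kk < headDim; kk++) kv += k[kk] * st[kk * headDim + vv];
        const float delta = v[vv] - kv;
        for (int kk = 0; kk < headDim; kk++) st[kk * headDim + vv] += beta * k[kk] * delta;
    }
    // Step 5: read out v = (q * scale)^T S using the updated state.
    for (int vv = 0; vv < headDim; vv++) {
        float acc = 0.0f;
        for (int kk = 0; kk < headDim; kk++) acc += q[kk] * scale * st[kk * headDim + vv];
        v[vv] = acc;
    }
}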
// ============================================================
// KDA: RMSNorm Gated
// out = (out / rms) * weight * sigmoid(g)
// ============================================================

__global__ void kda_rmsnorm_gated_kernel(
    float* out, const float* g, const float* weight, int headDim, float eps
) {
    extern __shared__ float sdata[];
    int tid = threadIdx.x;
    float* head = out + blockIdx.x * headDim;
    const float* gHead = g ? (g + blockIdx.x * headDim) : nullptr;

    // Compute sum of squares.
    float sum = 0.0f;
    for (int i = tid; i < headDim; i += blockDim.x) {
        float v = head[i];
        sum += v * v;
    }
    sdata[tid] = sum;
    __syncthreads();

    // Tree reduction (blockDim.x must be a power of two).
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    float inv = rsqrtf(sdata[0] / (float)headDim + eps);

    for (int i = tid; i < headDim; i += blockDim.x) {
        float y = head[i] * inv * weight[i];
        if (gHead != nullptr) {
            y *= 1.0f / (1.0f + __expf(-gHead[i]));  // sigmoid gate
        }
        head[i] = y;
    }
}

int cuda_rmsnorm_gated_f32(
    float* out, const float* g, const float* weight, int n, int headDim, float eps
) {
    if (n <= 0 || headDim <= 0) return 0;
    int numHeads = n / headDim;
    // Power-of-two thread count (capped at 256) for an exact tree reduction.
    int threads = 1;
    while (threads < headDim && threads < 256) threads <<= 1;
    size_t sharedMem = threads * sizeof(float);
    kda_rmsnorm_gated_kernel<<<numHeads, threads, sharedMem>>>(out, g, weight, headDim, eps);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
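// ------------------------------------------------------------
// Reference sketch (host-side, not part of the library API): gated RMSNorm
// for one head, matching the kernel above. rmsnorm_gated_head_ref is an
// illustrative name for test code.
// ------------------------------------------------------------
static inline void rmsnorm_gated_head_ref(
    float* head, const float* gHead, const float* weight, int headDim, float eps
) {
    float sum = 0.0f;
    for (int i = 0; i < headDim; i++) sum += head[i] * head[i];
    const float inv = 1.0f / sqrtf(sum / (float)headDim + eps);
    for (int i = 0; i < headDim; i++) {
        float y = head[i] * inv * weight[i];
        if (gHead != nullptr) y *= 1.0f / (1.0f + expf(-gHead[i]));  // sigmoid gate
        head[i] = y;
    }
}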
// ============================================================
// Sigmoid (for MoE router, etc.)
// ============================================================

__global__ void sigmoid_kernel(float* x, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        x[idx] = 1.0f / (1.0f + __expf(-x[idx]));
    }
}

int cuda_sigmoid_f32(float* x, int n) {
    if (n <= 0) return 0;
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    sigmoid_kernel<<<blocks, threads>>>(x, n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// ============================================================
// Softmax per row (for MoE router)
// ============================================================

__global__ void softmax_row_kernel(float* x, int cols) {
    extern __shared__ float sdata[];
    int row = blockIdx.x;
    int tid = threadIdx.x;
    float* rowData = x + row * cols;

    // Find max (for numerical stability).
    float maxVal = -1e30f;
    for (int i = tid; i < cols; i += blockDim.x) {
        maxVal = fmaxf(maxVal, rowData[i]);
    }
    sdata[tid] = maxVal;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }
    maxVal = sdata[0];
    __syncthreads();

    // Compute exp and sum.
    float sum = 0.0f;
    for (int i = tid; i < cols; i += blockDim.x) {
        float v = __expf(rowData[i] - maxVal);
        rowData[i] = v;
        sum += v;
    }
    sdata[tid] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    float invSum = 1.0f / sdata[0];

    // Normalize.
    for (int i = tid; i < cols; i += blockDim.x) {
        rowData[i] *= invSum;
    }
}

int cuda_softmax_rows_f32(float* x, int rows, int cols) {
    if (rows <= 0 || cols <= 0) return 0;
    // Power-of-two thread count (capped at 256) so both tree reductions are exact.
    int threads = 1;
    while (threads < cols && threads < 256) threads <<= 1;
    size_t sharedMem = threads * sizeof(float);
    softmax_row_kernel<<<rows, threads, sharedMem>>>(x, cols);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// ============================================================
// TopK per row (for MoE expert selection)
// ============================================================

__global__ void topk_per_row_kernel(
    const float* scores, int* indices, float* values, int cols, int k
) {
    int row = blockIdx.x;
    const float* rowScores = scores + row * cols;
    int* rowIndices = indices + row * k;
    float* rowValues = values + row * k;

    // Simple O(cols * k) selection - good enough for small k.
    for (int i = 0; i < k; i++) {
        float bestVal = -1e30f;
        int bestIdx = -1;
        for (int j = 0; j < cols; j++) {
            float v = rowScores[j];
            // Check if already selected.
            bool selected = false;
            for (int p = 0; p < i; p++) {
                if (rowIndices[p] == j) {
                    selected = true;
                    break;
                }
            }
            if (!selected && v > bestVal) {
                bestVal = v;
                bestIdx = j;
            }
        }
        rowIndices[i] = bestIdx;
        rowValues[i] = bestVal;
    }
}

int cuda_topk_per_row_f32(
    const float* scores, int* indices, float* values, int rows, int cols, int k
) {
    if (rows <= 0 || cols <= 0 || k <= 0) return 0;
    // One block per row, one thread per block: the selection is serial per row.
    topk_per_row_kernel<<<rows, 1>>>(scores, indices, values, cols, k);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
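// ------------------------------------------------------------
// Example sketch: how a MoE router pass might chain the wrappers above
// (row softmax over expert logits, then top-k selection). The tensor
// names, the softmax-then-topk order, and moe_route_example itself are
// assumptions about the calling code, not something this file prescribes;
// some routers apply cuda_sigmoid_f32 to the logits instead.
// ------------------------------------------------------------
static inline int moe_route_example(
    float* routerLogits,  // [tokens, numExperts] on device, overwritten in place
    int* expertIdx,       // [tokens, topK] on device
    float* expertWeight,  // [tokens, topK] on device
    int tokens, int numExperts, int topK
) {
    if (cuda_softmax_rows_f32(routerLogits, tokens, numExperts) != 0) return -1;
    if (cuda_topk_per_row_f32(routerLogits, expertIdx, expertWeight,
                              tokens, numExperts, topK) != 0) return -1;
    return 0;
}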