- #include "cuda_common.cuh"
- // --- Kernels ---
- __global__ void add_kernel(float* a, const float* b, int n) {
- int idx = blockIdx.x * blockDim.x + threadIdx.x;
- if (idx < n) {
- a[idx] += b[idx];
- }
- }
- int cuda_add_f32(float* a, float* b, size_t n) {
- int threads = 256;
- int blocks = (int)((n + threads - 1) / threads);
- add_kernel<<<blocks, threads>>>(a, b, (int)n);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
__global__ void mul_kernel(float* a, const float* b, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        a[idx] *= b[idx];
    }
}

int cuda_mul_f32(float* a, float* b, size_t n) {
    int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    mul_kernel<<<blocks, threads>>>(a, b, (int)n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// SiLU kernel: x = x * sigmoid(x)
__global__ void silu_kernel(float* x, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = x[idx];
        x[idx] = val / (1.0f + __expf(-val));
    }
}

int cuda_silu_f32(float* x, size_t n) {
    int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    silu_kernel<<<blocks, threads>>>(x, (int)n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// Element-wise multiply in-place (same math as cuda_mul_f32; kept as a
// separate entry point with a const-qualified operand)
__global__ void mul_inplace_kernel(float* a, const float* b, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        a[idx] *= b[idx];
    }
}

int cuda_mul_inplace_f32(float* a, const float* b, size_t n) {
    int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    mul_inplace_kernel<<<blocks, threads>>>(a, b, (int)n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// Device-to-device copy (a plain cudaMemcpy; no kernel needed)
int cuda_copy_f32(float* dst, const float* src, size_t n) {
    CHECK_CUDA(cudaMemcpy(dst, src, n * sizeof(float), cudaMemcpyDeviceToDevice));
    return 0;
}

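// Usage sketch (illustrative only, not part of the API): computes
// a = silu(a + b) for host arrays via the wrappers above. Assumes
// CHECK_CUDA returns a nonzero error code from the enclosing function
// on failure, as the wrappers in this file use it; note the error paths
// here do not free the device buffers.
static int example_add_silu(float* hostA, const float* hostB, size_t n) {
    float* dA = nullptr;
    float* dB = nullptr;
    CHECK_CUDA(cudaMalloc(&dA, n * sizeof(float)));
    CHECK_CUDA(cudaMalloc(&dB, n * sizeof(float)));
    CHECK_CUDA(cudaMemcpy(dA, hostA, n * sizeof(float), cudaMemcpyHostToDevice));
    CHECK_CUDA(cudaMemcpy(dB, hostB, n * sizeof(float), cudaMemcpyHostToDevice));
    if (cuda_add_f32(dA, dB, n) != 0) return 1;
    if (cuda_silu_f32(dA, n) != 0) return 1;
    CHECK_CUDA(cudaMemcpy(hostA, dA, n * sizeof(float), cudaMemcpyDeviceToHost));
    CHECK_CUDA(cudaFree(dA));
    CHECK_CUDA(cudaFree(dB));
    return 0;
}
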
// ============================================================
// KDA: Causal short conv1d + SiLU
// ============================================================
static __device__ __forceinline__ float sigmoid_f32(float x) {
    return 1.0f / (1.0f + __expf(-x));
}

static __device__ __forceinline__ float silu_f32(float x) {
    return x * sigmoid_f32(x);
}

// xTok:  [projSize]           current token's activations (updated in place)
// state: [projSize, convLen]  rolling window of the previous convLen inputs
// w:     [projSize, kernel]   per-channel filter taps (assumed contiguous)
__global__ void kda_causal_short_conv1d_token_kernel(
    float* xTok,
    float* state,
    const float* w,
    int projSize,
    int kernel,
    int convLen
) {
    int d = blockIdx.x * blockDim.x + threadIdx.x;
    if (d >= projSize) {
        return;
    }
    const int wBase = d * kernel;
    const int stBase = d * convLen;
    // Read the input before overwriting xTok.
    const float xIn = xTok[d];
    // Convolve the stored window, then apply the last tap to the current input.
    float acc = 0.0f;
    for (int j = 0; j < convLen; j++) {
        acc = fmaf(w[wBase + j], state[stBase + j], acc);
    }
    acc = fmaf(w[wBase + convLen], xIn, acc);
    xTok[d] = silu_f32(acc);
    // Update causal state: shift left and append xIn.
    if (convLen > 0) {
        for (int j = 0; j < convLen - 1; j++) {
            state[stBase + j] = state[stBase + j + 1];
        }
        state[stBase + convLen - 1] = xIn;
    }
}

int cuda_kda_causal_short_conv1d_f32(
    float* x,
    float* state,
    const float* w,
    int tokens,
    int projSize,
    int kernel
) {
    if (tokens <= 0 || projSize <= 0) {
        return 0;
    }
    if (kernel <= 1) {
        // Just SiLU.
        return cuda_silu_f32(x, (size_t)tokens * (size_t)projSize);
    }
    const int convLen = kernel - 1;
    int threads = 256;
    int blocks = (projSize + threads - 1) / threads;
    // Tokens must be processed in order: each launch consumes and updates
    // the rolling conv state.
    for (int t = 0; t < tokens; t++) {
        float* xTok = x + (size_t)t * (size_t)projSize;
        kda_causal_short_conv1d_token_kernel<<<blocks, threads>>>(xTok, state, w, projSize, kernel, convLen);
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}

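// A minimal CPU reference for one token of the causal short conv
// (illustrative, e.g. for unit-testing the kernel against host results;
// not called anywhere in this file). Same layouts as documented above;
// assumes kernel >= 2, as the launcher guarantees before launching.
static void kda_conv1d_token_ref(float* xTok, float* state, const float* w,
                                 int projSize, int kernel) {
    const int convLen = kernel - 1;
    for (int d = 0; d < projSize; d++) {
        const float xIn = xTok[d];
        float acc = 0.0f;
        for (int j = 0; j < convLen; j++) {
            acc += w[d * kernel + j] * state[d * convLen + j];
        }
        acc += w[d * kernel + convLen] * xIn;
        xTok[d] = acc / (1.0f + expf(-acc));  // SiLU
        for (int j = 0; j < convLen - 1; j++) {
            state[d * convLen + j] = state[d * convLen + j + 1];
        }
        state[d * convLen + convLen - 1] = xIn;
    }
}
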
// ============================================================
// KDA: L2 Norm Heads
// ============================================================
__global__ void kda_l2norm_head_kernel(float* x, int headDim, float eps) {
    // One block per head segment. The tree reduction below assumes
    // blockDim.x is a power of two.
    extern __shared__ float sdata[];

    int tid = threadIdx.x;
    float* head = x + (size_t)blockIdx.x * headDim;

    // Sum of squares (threads stride over the head)
    float sum = 0.0f;
    for (int i = tid; i < headDim; i += blockDim.x) {
        float v = head[i];
        sum += v * v;
    }
    sdata[tid] = sum;
    __syncthreads();

    // Tree reduction
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    float invNorm = rsqrtf(sdata[0] + eps);

    // Normalize
    for (int i = tid; i < headDim; i += blockDim.x) {
        head[i] *= invNorm;
    }
}

int cuda_l2norm_heads_f32(float* q, float* k, int tokens, int numHeads, int headDim, float eps) {
    if (tokens <= 0 || numHeads <= 0 || headDim <= 0) return 0;

    int totalHeads = tokens * numHeads;
    // Round up to a power of two (capped at 256): the in-kernel tree
    // reduction is only correct for power-of-two block sizes.
    int threads = 32;
    while (threads < headDim && threads < 256) threads <<= 1;
    size_t sharedMem = threads * sizeof(float);

    kda_l2norm_head_kernel<<<totalHeads, threads, sharedMem>>>(q, headDim, eps);
    CHECK_CUDA(cudaGetLastError());
    kda_l2norm_head_kernel<<<totalHeads, threads, sharedMem>>>(k, headDim, eps);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

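// CPU reference for one head segment (illustrative, for validation only):
// y = x * rsqrt(sum(x^2) + eps).
static void l2norm_head_ref(float* head, int headDim, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < headDim; i++) sum += head[i] * head[i];
    const float invNorm = 1.0f / sqrtf(sum + eps);
    for (int i = 0; i < headDim; i++) head[i] *= invNorm;
}
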
// ============================================================
// KDA: Gate computation
// g_out = -exp(aLog[h]) * softplus(g + dtBias)
// ============================================================
__device__ __forceinline__ float softplus_f32(float x) {
    // For large x, softplus(x) ~= x; the threshold avoids overflow in expf.
    return (x > 20.0f) ? x : logf(1.0f + __expf(x));
}

__global__ void kda_gate_kernel(
    const float* g,
    const float* aLog,
    const float* dtBias,
    float* out,
    int numHeads,
    int headDim
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int projSize = numHeads * headDim;
    if (idx >= projSize) return;

    int h = idx / headDim;
    float mul = -__expf(aLog[h]);
    float x = g[idx];
    if (dtBias != nullptr) {
        x += dtBias[idx];
    }
    out[idx] = mul * softplus_f32(x);
}

int cuda_kda_gate_f32(
    const float* g,
    const float* aLog,
    const float* dtBias,
    float* out,
    int tokens,
    int numHeads,
    int headDim
) {
    if (tokens <= 0) return 0;
    int projSize = numHeads * headDim;
    int threads = 256;
    int blocks = (projSize + threads - 1) / threads;

    // No cross-token dependency; launched per token for simplicity.
    for (int t = 0; t < tokens; t++) {
        const float* gTok = g + (size_t)t * projSize;
        float* outTok = out + (size_t)t * projSize;
        kda_gate_kernel<<<blocks, threads>>>(gTok, aLog, dtBias, outTok, numHeads, headDim);
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}

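// CPU reference for one token of the gate (illustrative, for validation
// only): out[h*headDim + i] = -exp(aLog[h]) * softplus(g[h*headDim + i]
// + dtBias[h*headDim + i]), matching the kernel above.
static void kda_gate_token_ref(const float* g, const float* aLog,
                               const float* dtBias, float* out,
                               int numHeads, int headDim) {
    for (int h = 0; h < numHeads; h++) {
        const float mul = -expf(aLog[h]);
        for (int i = 0; i < headDim; i++) {
            const int idx = h * headDim + i;
            float x = g[idx] + (dtBias ? dtBias[idx] : 0.0f);
            const float sp = (x > 20.0f) ? x : logf(1.0f + expf(x));
            out[idx] = mul * sp;
        }
    }
}
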
// ============================================================
// KDA: Recurrent (per-token, per-head)
// state[h]: [headDim, headDim]
// ============================================================
__global__ void kda_recurrent_step_kernel(
    const float* qTok,
    const float* kTok,
    float* vTok,
    const float* gTok,
    const float* betaTok,
    float* state,
    int numHeads,
    int headDim,
    float scale
) {
    // One block per head (blockIdx.x); threads stride over headDim elements.
    extern __shared__ float shared[];
    float* tmpKV = shared;
    float* tmpVM = shared + headDim;
    int h = blockIdx.x;
    if (h >= numHeads) return;
    int tid = threadIdx.x;
    int stateStride = headDim * headDim;
    int off = h * headDim;
    const float* q = qTok + off;
    const float* k = kTok + off;
    float* v = vTok + off;
    const float* g = gTok + off;
    float beta = betaTok[h];
    float* st = state + h * stateStride;
    // Step 1: Decay state by exp(g)
    for (int kk = tid; kk < headDim; kk += blockDim.x) {
        float dec = __expf(g[kk]);
        for (int vv = 0; vv < headDim; vv++) {
            st[kk * headDim + vv] *= dec;
        }
    }
    __syncthreads();
    // Step 2: tmpKV = k^T @ state (for each v dimension)
    for (int vv = tid; vv < headDim; vv += blockDim.x) {
        float acc = 0.0f;
        for (int kk = 0; kk < headDim; kk++) {
            acc += k[kk] * st[kk * headDim + vv];
        }
        tmpKV[vv] = acc;
    }
    __syncthreads();
    // Step 3: tmpVM = v - tmpKV (delta-rule residual)
    for (int vv = tid; vv < headDim; vv += blockDim.x) {
        tmpVM[vv] = v[vv] - tmpKV[vv];
    }
    __syncthreads();
    // Step 4: state += beta * k @ tmpVM^T
    for (int kk = tid; kk < headDim; kk += blockDim.x) {
        float kj = beta * k[kk];
        for (int vv = 0; vv < headDim; vv++) {
            st[kk * headDim + vv] += kj * tmpVM[vv];
        }
    }
    __syncthreads();
    // Step 5: v = (q * scale)^T @ state
    for (int vv = tid; vv < headDim; vv += blockDim.x) {
        float acc = 0.0f;
        for (int kk = 0; kk < headDim; kk++) {
            acc += (q[kk] * scale) * st[kk * headDim + vv];
        }
        v[vv] = acc;
    }
}

int cuda_kda_recurrent_f32(
    const float* q,
    const float* k,
    float* v,
    const float* g,
    const float* beta,
    float* state,
    int tokens,
    int numHeads,
    int headDim
) {
    if (tokens <= 0 || numHeads <= 0 || headDim <= 0) return 0;

    int projSize = numHeads * headDim;
    float scale = 1.0f / sqrtf((float)headDim);

    int threads = (headDim < 256) ? headDim : 256;
    size_t sharedMem = 2 * (size_t)headDim * sizeof(float);

    // Tokens must be processed in order: the state is carried across steps.
    for (int t = 0; t < tokens; t++) {
        const float* qTok = q + (size_t)t * projSize;
        const float* kTok = k + (size_t)t * projSize;
        float* vTok = v + (size_t)t * projSize;
        const float* gTok = g + (size_t)t * projSize;
        const float* betaTok = beta + (size_t)t * numHeads;
        kda_recurrent_step_kernel<<<numHeads, threads, sharedMem>>>(
            qTok, kTok, vTok, gTok, betaTok, state, numHeads, headDim, scale
        );
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}

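// CPU reference for one token/head of the recurrent update (illustrative,
// for validating the kernel on small shapes; not called in this file).
// Column vv of the state only interacts with column vv, so the middle
// steps of the kernel can be fused per column here.
static void kda_recurrent_head_ref(const float* q, const float* k, float* v,
                                   const float* g, float beta, float* st,
                                   int headDim, float scale) {
    // Decay: st[kk][vv] *= exp(g[kk])
    for (int kk = 0; kk < headDim; kk++) {
        const float dec = expf(g[kk]);
        for (int vv = 0; vv < headDim; vv++) st[kk * headDim + vv] *= dec;
    }
    // Delta-rule update, one state column at a time
    for (int vv = 0; vv < headDim; vv++) {
        float kv = 0.0f;
        for (int kk = 0; kk < headDim; kk++) kv += k[kk] * st[kk * headDim + vv];
        const float delta = v[vv] - kv;
        for (int kk = 0; kk < headDim; kk++) st[kk * headDim + vv] += beta * k[kk] * delta;
    }
    // Readout: v = (q * scale)^T @ st
    for (int vv = 0; vv < headDim; vv++) {
        float acc = 0.0f;
        for (int kk = 0; kk < headDim; kk++) acc += (q[kk] * scale) * st[kk * headDim + vv];
        v[vv] = acc;
    }
}
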
// ============================================================
// KDA: RMSNorm Gated
// out = (out / rms) * weight * sigmoid(g)
// ============================================================
__global__ void kda_rmsnorm_gated_kernel(
    float* out,
    const float* g,
    const float* weight,
    int headDim,
    float eps
) {
    // One block per head segment; the tree reduction assumes blockDim.x
    // is a power of two.
    extern __shared__ float sdata[];

    int tid = threadIdx.x;
    float* head = out + (size_t)blockIdx.x * headDim;
    const float* gHead = g ? (g + (size_t)blockIdx.x * headDim) : nullptr;

    // Sum of squares
    float sum = 0.0f;
    for (int i = tid; i < headDim; i += blockDim.x) {
        float v = head[i];
        sum += v * v;
    }
    sdata[tid] = sum;
    __syncthreads();

    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    float inv = rsqrtf(sdata[0] / (float)headDim + eps);

    for (int i = tid; i < headDim; i += blockDim.x) {
        float y = head[i] * inv * weight[i];
        if (gHead != nullptr) {
            y *= 1.0f / (1.0f + __expf(-gHead[i])); // sigmoid gate
        }
        head[i] = y;
    }
}

int cuda_rmsnorm_gated_f32(
    float* out,
    const float* g,
    const float* weight,
    int n,
    int headDim,
    float eps
) {
    if (n <= 0 || headDim <= 0) return 0;

    // n is assumed to be a multiple of headDim.
    int numHeads = n / headDim;
    // Power-of-two block size (capped at 256) for the tree reduction.
    int threads = 32;
    while (threads < headDim && threads < 256) threads <<= 1;
    size_t sharedMem = threads * sizeof(float);

    kda_rmsnorm_gated_kernel<<<numHeads, threads, sharedMem>>>(out, g, weight, headDim, eps);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

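// CPU reference for one head segment (illustrative, for validation only):
// y = x / sqrt(mean(x^2) + eps) * weight, optionally gated by sigmoid(g).
static void rmsnorm_gated_head_ref(float* head, const float* gHead,
                                   const float* weight, int headDim, float eps) {
    float sum = 0.0f;
    for (int i = 0; i < headDim; i++) sum += head[i] * head[i];
    const float inv = 1.0f / sqrtf(sum / (float)headDim + eps);
    for (int i = 0; i < headDim; i++) {
        float y = head[i] * inv * weight[i];
        if (gHead) y *= 1.0f / (1.0f + expf(-gHead[i]));
        head[i] = y;
    }
}
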
// ============================================================
// Sigmoid (for MoE router, etc.)
// ============================================================
__global__ void sigmoid_kernel(float* x, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        x[idx] = 1.0f / (1.0f + __expf(-x[idx]));
    }
}

int cuda_sigmoid_f32(float* x, int n) {
    if (n <= 0) return 0;
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    sigmoid_kernel<<<blocks, threads>>>(x, n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// ============================================================
// Softmax per row (for MoE router)
// ============================================================
__global__ void softmax_row_kernel(float* x, int cols) {
    // One block per row; the tree reductions assume blockDim.x is a power of two.
    extern __shared__ float sdata[];
    int row = blockIdx.x;
    int tid = threadIdx.x;
    float* rowData = x + (size_t)row * cols;

    // Row max (for numerical stability)
    float maxVal = -1e30f;
    for (int i = tid; i < cols; i += blockDim.x) {
        maxVal = fmaxf(maxVal, rowData[i]);
    }
    sdata[tid] = maxVal;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }
    maxVal = sdata[0];
    __syncthreads();

    // Exponentiate and sum
    float sum = 0.0f;
    for (int i = tid; i < cols; i += blockDim.x) {
        float v = __expf(rowData[i] - maxVal);
        rowData[i] = v;
        sum += v;
    }
    sdata[tid] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    float invSum = 1.0f / sdata[0];

    // Normalize
    for (int i = tid; i < cols; i += blockDim.x) {
        rowData[i] *= invSum;
    }
}

int cuda_softmax_rows_f32(float* x, int rows, int cols) {
    if (rows <= 0 || cols <= 0) return 0;
    // Power-of-two block size (capped at 256) for the tree reductions.
    int threads = 32;
    while (threads < cols && threads < 256) threads <<= 1;
    size_t sharedMem = threads * sizeof(float);
    softmax_row_kernel<<<rows, threads, sharedMem>>>(x, cols);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

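// CPU reference for one row (illustrative, for validation only): the
// standard max-shifted softmax, matching the kernel above.
static void softmax_row_ref(float* row, int cols) {
    float maxVal = row[0];
    for (int i = 1; i < cols; i++) maxVal = fmaxf(maxVal, row[i]);
    float sum = 0.0f;
    for (int i = 0; i < cols; i++) {
        row[i] = expf(row[i] - maxVal);
        sum += row[i];
    }
    for (int i = 0; i < cols; i++) row[i] /= sum;
}
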
// ============================================================
// TopK per row (for MoE expert selection)
// ============================================================
__global__ void topk_per_row_kernel(
    const float* scores,
    int* indices,
    float* values,
    int cols,
    int k
) {
    // One thread per row; O(cols * k) selection is good enough for the
    // small k typical of MoE routing.
    int row = blockIdx.x;
    const float* rowScores = scores + (size_t)row * cols;
    int* rowIndices = indices + (size_t)row * k;
    float* rowValues = values + (size_t)row * k;

    for (int i = 0; i < k; i++) {
        float bestVal = -1e30f;
        int bestIdx = -1;
        for (int j = 0; j < cols; j++) {
            float v = rowScores[j];
            // Skip columns that were already selected.
            bool selected = false;
            for (int p = 0; p < i; p++) {
                if (rowIndices[p] == j) { selected = true; break; }
            }
            if (!selected && v > bestVal) {
                bestVal = v;
                bestIdx = j;
            }
        }
        rowIndices[i] = bestIdx;
        rowValues[i] = bestVal;
    }
}

int cuda_topk_per_row_f32(
    const float* scores,
    int* indices,
    float* values,
    int rows,
    int cols,
    int k
) {
    if (rows <= 0 || cols <= 0 || k <= 0) return 0;
    topk_per_row_kernel<<<rows, 1>>>(scores, indices, values, cols, k);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

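// Usage sketch (illustrative, not part of the API): a typical MoE routing
// step on device buffers. `logits` is [tokens, numExperts] on the device;
// `indices` and `values` are [tokens, k]. Softmax the logits per token,
// then pick the top-k experts per token.
static int example_moe_route(float* logits, int* indices, float* values,
                             int tokens, int numExperts, int k) {
    if (cuda_softmax_rows_f32(logits, tokens, numExperts) != 0) return 1;
    if (cuda_topk_per_row_f32(logits, indices, values, tokens, numExperts, k) != 0) return 1;
    return 0;
}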