// cuda_elementwise.cu
#include "cuda_common.cuh"

// --- Kernels ---

__global__ void add_kernel(float* a, const float* b, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        a[idx] += b[idx];
    }
}

int cuda_add_f32(float* a, float* b, size_t n) {
    int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    add_kernel<<<blocks, threads>>>(a, b, (int)n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
__global__ void mul_kernel(float* a, const float* b, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        a[idx] *= b[idx];
    }
}

int cuda_mul_f32(float* a, float* b, size_t n) {
    int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    mul_kernel<<<blocks, threads>>>(a, b, (int)n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// SiLU kernel: x = x * sigmoid(x) = x / (1 + exp(-x))
__global__ void silu_kernel(float* x, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        float val = x[idx];
        x[idx] = val / (1.0f + __expf(-val));
    }
}

int cuda_silu_f32(float* x, size_t n) {
    int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    silu_kernel<<<blocks, threads>>>(x, (int)n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// Element-wise multiply in-place
__global__ void mul_inplace_kernel(float* a, const float* b, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        a[idx] *= b[idx];
    }
}

int cuda_mul_inplace_f32(float* a, const float* b, size_t n) {
    int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    mul_inplace_kernel<<<blocks, threads>>>(a, b, (int)n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
// Device-to-device copy (plain cudaMemcpy, no kernel needed)
int cuda_copy_f32(float* dst, const float* src, size_t n) {
    CHECK_CUDA(cudaMemcpy(dst, src, n * sizeof(float), cudaMemcpyDeviceToDevice));
    return 0;
}
// ============================================================
// KDA: Causal short conv1d + SiLU
// ============================================================
static __device__ __forceinline__ float sigmoid_f32(float x) {
    return 1.0f / (1.0f + __expf(-x));
}

static __device__ __forceinline__ float silu_f32(float x) {
    return x * sigmoid_f32(x);
}

// xTok:  [projSize]           current token, overwritten with the conv output
// state: [projSize, convLen]  previous convLen inputs per channel, oldest first
// w:     [projSize, kernel]   (assumed contiguous)
__global__ void kda_causal_short_conv1d_token_kernel(
    float* xTok,
    float* state,
    const float* w,
    int projSize,
    int kernel,
    int convLen
) {
    int d = blockIdx.x * blockDim.x + threadIdx.x;
    if (d >= projSize) {
        return;
    }
    const int wBase = d * kernel;
    const int stBase = d * convLen;
    // Read input before overwriting xTok.
    const float xIn = xTok[d];
    float acc = 0.0f;
    for (int j = 0; j < convLen; j++) {
        acc = fmaf(w[wBase + j], state[stBase + j], acc);
    }
    acc = fmaf(w[wBase + convLen], xIn, acc);
    xTok[d] = silu_f32(acc);
    // Update causal state: shift left and append xIn.
    if (convLen > 0) {
        for (int j = 0; j < convLen - 1; j++) {
            state[stBase + j] = state[stBase + j + 1];
        }
        state[stBase + convLen - 1] = xIn;
    }
}
int cuda_kda_causal_short_conv1d_f32(
    float* x,
    float* state,
    const float* w,
    int tokens,
    int projSize,
    int kernel
) {
    if (tokens <= 0 || projSize <= 0) {
        return 0;
    }
    if (kernel <= 1) {
        // Just SiLU.
        return cuda_silu_f32(x, (size_t)tokens * (size_t)projSize);
    }
    const int convLen = kernel - 1;
    int threads = 256;
    int blocks = (projSize + threads - 1) / threads;
    // Tokens are processed sequentially: each launch updates the shared conv state.
    for (int t = 0; t < tokens; t++) {
        float* xTok = x + (size_t)t * (size_t)projSize;
        kda_causal_short_conv1d_token_kernel<<<blocks, threads>>>(xTok, state, w, projSize, kernel, convLen);
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}
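
// For reference, the per-channel update the kernel above implements (a transcription
// of the code, not an authoritative spec): with kernel width K and convLen = K - 1,
//
//   y_t[d] = SiLU( sum_{j=0}^{K-2} w[d, j] * state[d, j]  +  w[d, K-1] * x_t[d] )
//
// where state[d, 0..K-2] holds x_{t-(K-1)}[d] .. x_{t-1}[d] (oldest first), so
// w[d, j] multiplies x_{t-(K-1)+j}[d]. After producing y_t, the window slides
// forward by one token. E.g. with K = 4 each output depends only on the current
// input and the three previous inputs of the same channel, which is what makes
// the convolution causal.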
// ============================================================
// KDA: L2 Norm Heads
// ============================================================
__global__ void kda_l2norm_head_kernel(float* x, int headDim, float eps) {
    // One block per head segment
    extern __shared__ float sdata[];
    int tid = threadIdx.x;
    float* head = x + blockIdx.x * headDim;
    // Compute sum of squares
    float sum = 0.0f;
    for (int i = tid; i < headDim; i += blockDim.x) {
        float v = head[i];
        sum += v * v;
    }
    sdata[tid] = sum;
    __syncthreads();
    // Tree reduction. Requires blockDim.x to be a power of two; the host launcher
    // below rounds the thread count accordingly.
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    float invNorm = rsqrtf(sdata[0] + eps);
    // Normalize
    for (int i = tid; i < headDim; i += blockDim.x) {
        head[i] *= invNorm;
    }
}
int cuda_l2norm_heads_f32(float* q, float* k, int tokens, int numHeads, int headDim, float eps) {
    if (tokens <= 0 || numHeads <= 0 || headDim <= 0) return 0;
    int totalHeads = tokens * numHeads;
    // Round the block size up to a power of two (capped at 256) so the in-kernel
    // tree reduction covers every partial sum.
    int threads = 1;
    while (threads < headDim && threads < 256) threads <<= 1;
    size_t sharedMem = threads * sizeof(float);
    kda_l2norm_head_kernel<<<totalHeads, threads, sharedMem>>>(q, headDim, eps);
    CHECK_CUDA(cudaGetLastError());
    kda_l2norm_head_kernel<<<totalHeads, threads, sharedMem>>>(k, headDim, eps);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
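
// Each (token, head) segment of length headDim is scaled by 1 / sqrt(sum(x_i^2) + eps),
// i.e. it is projected (approximately, because of eps) onto the unit sphere. A quick
// sanity check with eps = 0: the head [3, 4] has sum of squares 25, so invNorm = 1/5
// and the normalized head is [0.6, 0.8], whose L2 norm is 1.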
// ============================================================
// KDA: Gate computation
// g_out = -exp(aLog[h]) * softplus(g + dtBias)
// ============================================================
__device__ __forceinline__ float softplus_f32(float x) {
    // For large x, softplus(x) = log(1 + exp(x)) ~= x; the threshold avoids overflow in __expf.
    return (x > 20.0f) ? x : logf(1.0f + __expf(x));
}

__global__ void kda_gate_kernel(
    const float* g,
    const float* aLog,
    const float* dtBias,
    float* out,
    int numHeads,
    int headDim
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int projSize = numHeads * headDim;
    if (idx >= projSize) return;
    int h = idx / headDim;
    float mul = -__expf(aLog[h]);
    float x = g[idx];
    if (dtBias != nullptr) {
        x += dtBias[idx];
    }
    out[idx] = mul * softplus_f32(x);
}
int cuda_kda_gate_f32(
    const float* g,
    const float* aLog,
    const float* dtBias,
    float* out,
    int tokens,
    int numHeads,
    int headDim
) {
    if (tokens <= 0) return 0;
    int projSize = numHeads * headDim;
    int threads = 256;
    int blocks = (projSize + threads - 1) / threads;
    for (int t = 0; t < tokens; t++) {
        const float* gTok = g + (size_t)t * (size_t)projSize;
        float* outTok = out + (size_t)t * (size_t)projSize;
        kda_gate_kernel<<<blocks, threads>>>(gTok, aLog, dtBias, outTok, numHeads, headDim);
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}
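
// Note on the sign: exp(aLog[h]) > 0 and softplus(.) >= 0, so every gate value is <= 0.
// The recurrent kernel below uses exp(g) as its per-channel decay, which therefore
// lands in (0, 1] — the state can only shrink or be kept, never amplified.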
// ============================================================
// KDA: Recurrent (per-token, per-head)
// state[h]: [headDim, headDim]
// ============================================================
__global__ void kda_recurrent_step_kernel(
    const float* qTok,
    const float* kTok,
    float* vTok,
    const float* gTok,
    const float* betaTok,
    float* state,
    int numHeads,
    int headDim,
    float scale
) {
    // One block per head (blockIdx.x), threads work on headDim elements.
    extern __shared__ float shared[];
    float* tmpKV = shared;
    float* tmpVM = shared + headDim;
    int h = blockIdx.x;
    if (h >= numHeads) return;
    int tid = threadIdx.x;
    int stateStride = headDim * headDim;
    int off = h * headDim;
    const float* q = qTok + off;
    const float* k = kTok + off;
    float* v = vTok + off;
    const float* g = gTok + off;
    float beta = betaTok[h];
    float* st = state + h * stateStride;
    // Step 1: Decay state by exp(g)
    for (int kk = tid; kk < headDim; kk += blockDim.x) {
        float dec = __expf(g[kk]);
        for (int vv = 0; vv < headDim; vv++) {
            st[kk * headDim + vv] *= dec;
        }
    }
    __syncthreads();
    // Step 2: tmpKV = k^T @ state (for each v dimension)
    for (int vv = tid; vv < headDim; vv += blockDim.x) {
        float acc = 0.0f;
        for (int kk = 0; kk < headDim; kk++) {
            acc += k[kk] * st[kk * headDim + vv];
        }
        tmpKV[vv] = acc;
    }
    __syncthreads();
    // Step 3: tmpVM = v - tmpKV
    for (int vv = tid; vv < headDim; vv += blockDim.x) {
        tmpVM[vv] = v[vv] - tmpKV[vv];
    }
    __syncthreads();
    // Step 4: state += beta * k @ tmpVM^T
    for (int kk = tid; kk < headDim; kk += blockDim.x) {
        float kj = beta * k[kk];
        for (int vv = 0; vv < headDim; vv++) {
            st[kk * headDim + vv] += kj * tmpVM[vv];
        }
    }
    __syncthreads();
    // Step 5: v = (q * scale)^T @ state
    for (int vv = tid; vv < headDim; vv += blockDim.x) {
        float acc = 0.0f;
        for (int kk = 0; kk < headDim; kk++) {
            acc += (q[kk] * scale) * st[kk * headDim + vv];
        }
        v[vv] = acc;
    }
}
int cuda_kda_recurrent_f32(
    const float* q,
    const float* k,
    float* v,
    const float* g,
    const float* beta,
    float* state,
    int tokens,
    int numHeads,
    int headDim
) {
    if (tokens <= 0 || numHeads <= 0 || headDim <= 0) return 0;
    int projSize = numHeads * headDim;
    float scale = 1.0f / sqrtf((float)headDim);
    int threads = (headDim < 256) ? headDim : 256;
    size_t sharedMem = 2 * (size_t)headDim * sizeof(float);
    for (int t = 0; t < tokens; t++) {
        const float* qTok = q + (size_t)t * (size_t)projSize;
        const float* kTok = k + (size_t)t * (size_t)projSize;
        float* vTok = v + (size_t)t * (size_t)projSize;
        const float* gTok = g + (size_t)t * (size_t)projSize;
        const float* betaTok = beta + (size_t)t * (size_t)numHeads;
        kda_recurrent_step_kernel<<<numHeads, threads, sharedMem>>>(
            qTok, kTok, vTok, gTok, betaTok, state, numHeads, headDim, scale
        );
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}
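
// In matrix form, one step above is (per head h, with state S in R^{headDim x headDim},
// keys indexing rows and values indexing columns — just a transcription of the kernel,
// written out for readability):
//
//   S   <- diag(exp(g)) * S           // Step 1: per-key-channel decay
//   u    = v - S^T k                  // Steps 2-3: prediction error ("delta rule")
//   S   <- S + beta * k u^T           // Step 4: rank-1 correction, gated by beta
//   out  = S^T (q * scale)            // Step 5: read out with the scaled query
//
// The outputs are written back over v in place, one token at a time, which is why
// the host loop in cuda_kda_recurrent_f32 stays sequential across tokens.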
// ============================================================
// KDA: RMSNorm Gated
// out = (out / rms) * weight * sigmoid(g)
// ============================================================
__global__ void kda_rmsnorm_gated_kernel(
    float* out,
    const float* g,
    const float* weight,
    int headDim,
    float eps
) {
    extern __shared__ float sdata[];
    int tid = threadIdx.x;
    float* head = out + blockIdx.x * headDim;
    const float* gHead = g ? (g + blockIdx.x * headDim) : nullptr;
    // Compute sum of squares
    float sum = 0.0f;
    for (int i = tid; i < headDim; i += blockDim.x) {
        float v = head[i];
        sum += v * v;
    }
    sdata[tid] = sum;
    __syncthreads();
    // Tree reduction (blockDim.x must be a power of two; see the host launcher).
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }
    float inv = rsqrtf(sdata[0] / (float)headDim + eps);
    for (int i = tid; i < headDim; i += blockDim.x) {
        float y = head[i] * inv * weight[i];
        if (gHead != nullptr) {
            y *= sigmoid_f32(gHead[i]); // sigmoid gate
        }
        head[i] = y;
    }
}
int cuda_rmsnorm_gated_f32(
    float* out,
    const float* g,
    const float* weight,
    int n,
    int headDim,
    float eps
) {
    if (n <= 0 || headDim <= 0) return 0;
    int numHeads = n / headDim;  // n is expected to be a multiple of headDim
    // Power-of-two block size (capped at 256) for the tree reduction.
    int threads = 1;
    while (threads < headDim && threads < 256) threads <<= 1;
    size_t sharedMem = threads * sizeof(float);
    kda_rmsnorm_gated_kernel<<<numHeads, threads, sharedMem>>>(out, g, weight, headDim, eps);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
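
// Per group of headDim elements this computes
//   out_i <- out_i / sqrt(mean_j(out_j^2) + eps) * weight_i * sigmoid(g_i)
// matching the "out = (out / rms) * weight * sigmoid(g)" comment above; when g is null
// the sigmoid gate is simply skipped. Note that weight is indexed per position within
// the group, i.e. it is a single [headDim] vector shared by all groups.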
// ============================================================
// Sigmoid (for MoE router, etc.)
// ============================================================
__global__ void sigmoid_kernel(float* x, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) {
        x[idx] = 1.0f / (1.0f + __expf(-x[idx]));
    }
}

int cuda_sigmoid_f32(float* x, int n) {
    if (n <= 0) return 0;
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    sigmoid_kernel<<<blocks, threads>>>(x, n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
// ============================================================
// Softmax per row (for MoE router)
// ============================================================
__global__ void softmax_row_kernel(float* x, int cols) {
    extern __shared__ float sdata[];
    int row = blockIdx.x;
    int tid = threadIdx.x;
    float* rowData = x + row * cols;
    // Find the row max (subtracted below for numerical stability)
    float maxVal = -1e30f;
    for (int i = tid; i < cols; i += blockDim.x) {
        maxVal = fmaxf(maxVal, rowData[i]);
    }
    sdata[tid] = maxVal;
    __syncthreads();
    // Max reduction (blockDim.x must be a power of two; see the host launcher)
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        __syncthreads();
    }
    maxVal = sdata[0];
    __syncthreads();
    // Compute exp and sum
    float sum = 0.0f;
    for (int i = tid; i < cols; i += blockDim.x) {
        float v = __expf(rowData[i] - maxVal);
        rowData[i] = v;
        sum += v;
    }
    sdata[tid] = sum;
    __syncthreads();
    // Sum reduction
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    float invSum = 1.0f / sdata[0];
    // Normalize
    for (int i = tid; i < cols; i += blockDim.x) {
        rowData[i] *= invSum;
    }
}
int cuda_softmax_rows_f32(float* x, int rows, int cols) {
    if (rows <= 0 || cols <= 0) return 0;
    // Power-of-two block size (capped at 256) for the tree reductions.
    int threads = 1;
    while (threads < cols && threads < 256) threads <<= 1;
    size_t sharedMem = threads * sizeof(float);
    softmax_row_kernel<<<rows, threads, sharedMem>>>(x, cols);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
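
// Subtracting the row max before exponentiating does not change the result
// (softmax(x - c) == softmax(x) for any constant c) but keeps __expf in range.
// For example, a row of logits [1, 2, 3] becomes roughly [0.090, 0.245, 0.665].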
// ============================================================
// TopK per row (for MoE expert selection)
// ============================================================
__global__ void topk_per_row_kernel(
    const float* scores,
    int* indices,
    float* values,
    int cols,
    int k
) {
    int row = blockIdx.x;
    const float* rowScores = scores + row * cols;
    int* rowIndices = indices + row * k;
    float* rowValues = values + row * k;
    // Simple O(n*k) selection - good enough for small k
    for (int i = 0; i < k; i++) {
        float bestVal = -1e30f;
        int bestIdx = -1;
        for (int j = 0; j < cols; j++) {
            float v = rowScores[j];
            // Check if already selected
            bool selected = false;
            for (int p = 0; p < i; p++) {
                if (rowIndices[p] == j) { selected = true; break; }
            }
            if (!selected && v > bestVal) {
                bestVal = v;
                bestIdx = j;
            }
        }
        rowIndices[i] = bestIdx;
        rowValues[i] = bestVal;
    }
}
int cuda_topk_per_row_f32(
    const float* scores,
    int* indices,
    float* values,
    int rows,
    int cols,
    int k
) {
    if (rows <= 0 || cols <= 0 || k <= 0) return 0;
    topk_per_row_kernel<<<rows, 1>>>(scores, indices, values, cols, k);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
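
// A hedged sketch of how the router helpers above might be combined on the host side —
// the tensor names (routerLogits, expertIdx, expertWeight) and the choice of softmax
// over sigmoid are assumptions for illustration, not something this file prescribes:
//
//   // routerLogits: [tokens, numExperts] on the device
//   cuda_softmax_rows_f32(routerLogits, tokens, numExperts);
//   cuda_topk_per_row_f32(routerLogits, expertIdx, expertWeight, tokens, numExperts, topK);
//   // expertIdx: [tokens, topK] selected experts; expertWeight: their routing weights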