// cuda_nn.cu

  1. #include "cuda_common.cuh"
  2. #include <cuda_fp16.h>
  3. #include <stdint.h>
  4. #include <string.h>
  5. namespace {
  6. constexpr int kPagedAttentionSplitSize = 1024;
  7. constexpr int kPagedAttentionSplitQHThreshold = 4096; // queryCount*numHeads threshold
  8. } // namespace
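// kPagedAttentionSplitSize is the number of KV tokens each split covers in the
// split-KV ("flash decoding") path. kPagedAttentionSplitQHThreshold gates that
// path: once queryCount * numHeads alone yields enough warps to fill the GPU,
// splitting the KV range only adds overhead, so it is skipped.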
  9. // ============================================================
  10. // Fused RoPE helpers (constant step table + fast complex pow)
  11. // ============================================================
  12. // Stores cos/sin of per-dimension "step" angle (invFreq) for RoPE.
  13. // Only indices [0, headDim/2) are used. Max supported headDim is 256.
  14. __device__ __constant__ float rope_cos_step_const[128];
  15. __device__ __constant__ float rope_sin_step_const[128];
  16. static int g_rope_step_inited[32] = {0};
  17. static int g_rope_step_head_dim[32] = {0};
  18. static uint32_t g_rope_step_theta_bits[32] = {0};
  19. static int ensure_rope_step_table(int headDim, float theta) {
  20. if (headDim <= 0 || headDim > 256 || (headDim & 1) != 0) {
  21. return 1;
  22. }
  23. int dev = 0;
  24. CHECK_CUDA(cudaGetDevice(&dev));
  25. if (dev < 0 || dev >= 32) {
  26. dev = 0;
  27. }
  28. uint32_t thetaBits = 0;
  29. memcpy(&thetaBits, &theta, sizeof(thetaBits));
  30. if (g_rope_step_inited[dev] && g_rope_step_head_dim[dev] == headDim && g_rope_step_theta_bits[dev] == thetaBits) {
  31. return 0;
  32. }
  33. float cosStep[128];
  34. float sinStep[128];
  35. const int halfDim = headDim / 2;
  36. for (int j = 0; j < 128; j++) {
  37. if (j < halfDim) {
  38. // invFreq = theta^(-2j/headDim)
  39. const double exp = -2.0 * (double)j / (double)headDim;
  40. const double invFreq = pow((double)theta, exp);
  41. cosStep[j] = (float)cos(invFreq);
  42. sinStep[j] = (float)sin(invFreq);
  43. } else {
  44. cosStep[j] = 1.0f;
  45. sinStep[j] = 0.0f;
  46. }
  47. }
  48. CHECK_CUDA(cudaMemcpyToSymbol(rope_cos_step_const, cosStep, sizeof(cosStep), 0, cudaMemcpyHostToDevice));
  49. CHECK_CUDA(cudaMemcpyToSymbol(rope_sin_step_const, sinStep, sizeof(sinStep), 0, cudaMemcpyHostToDevice));
  50. g_rope_step_inited[dev] = 1;
  51. g_rope_step_head_dim[dev] = headDim;
  52. g_rope_step_theta_bits[dev] = thetaBits;
  53. return 0;
  54. }
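// Illustrative reference (added sketch; not used by the kernels): for pair index j
// and position p, the fused-RoPE kernels reconstruct the phase
// (cos(p * invFreq_j), sin(p * invFreq_j)) with invFreq_j = theta^(-2j/headDim)
// by raising the per-step rotation stored above to the p-th power (complex_pow_int).
// A direct host-side equivalent of that phase:
static inline void rope_phase_reference(int pos, int j, int headDim, float theta,
float* cosOut, float* sinOut) {
const double invFreq = pow((double)theta, -2.0 * (double)j / (double)headDim);
*cosOut = (float)cos((double)pos * invFreq);
*sinOut = (float)sin((double)pos * invFreq);
}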
  55. __device__ __forceinline__ float2 complex_mul_f2(float2 a, float2 b) {
  56. // (a.x + i a.y) * (b.x + i b.y)
  57. return make_float2(
  58. fmaf(a.x, b.x, -a.y * b.y),
  59. fmaf(a.x, b.y, a.y * b.x)
  60. );
  61. }
  62. __device__ __forceinline__ float2 complex_pow_int(float2 base, int exp) {
  63. float2 result = make_float2(1.0f, 0.0f);
  64. float2 b = base;
  65. int e = exp;
  66. while (e > 0) {
  67. if (e & 1) {
  68. result = complex_mul_f2(result, b);
  69. }
  70. b = complex_mul_f2(b, b);
  71. e >>= 1;
  72. }
  73. return result;
  74. }
  75. __device__ __forceinline__ void rope_advance_neg(float& cosv, float& sinv, float cosStep, float sinStep) {
  76. // Multiply by exp(-i*step): (cos + i sin) * (cosStep - i sinStep)
  77. const float c = cosv;
  78. const float s = sinv;
  79. cosv = fmaf(c, cosStep, s * sinStep);
  80. sinv = fmaf(s, cosStep, -c * sinStep);
  81. }
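// How the kernels below use these helpers: RoPE rotates each pair (x[j], x[j+halfDim])
// by angle pos * invFreq_j, and the dot product of a rotated query and a rotated key
// collapses to a function of the position difference:
//   q'.k' = (q0*k0 + q1*k1) * cos((pq - pk) * f) + (q0*k1 - q1*k0) * sin((pq - pk) * f)
// The fused kernels therefore initialize (cosDelta, sinDelta) = step^pq via
// complex_pow_int (the first KV token sits at position 0) and call rope_advance_neg
// once per KV token to step the delta from (pq - kv) down to (pq - kv - 1).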
  82. // ============================================================
  83. // Neural Network Operations
  84. // ============================================================
  85. // RMSNorm kernel: one block per row
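// Per row: y[i] = x[i] * w[i] * rsqrt(mean_j(x[j]^2) + eps), computed in place.
// Assumes blockDim.x == 256 (8 warps), matching the launch in cuda_rmsnorm_f32.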
  86. __global__ void rmsnorm_kernel(float* x, const float* w, int dim, float eps) {
  87. int row = blockIdx.x;
  88. float* rowData = x + row * dim;
  89. float sum = 0.0f;
  90. for (int i = threadIdx.x; i < dim; i += blockDim.x) {
  91. float v = rowData[i];
  92. sum = fmaf(v, v, sum);
  93. }
  94. // Warp reduce
  95. for (int offset = 16; offset > 0; offset >>= 1) {
  96. sum += __shfl_down_sync(0xffffffff, sum, offset);
  97. }
  98. __shared__ float warpSum[8];
  99. __shared__ float rms;
  100. int lane = threadIdx.x & 31;
  101. int warp = threadIdx.x >> 5;
  102. if (lane == 0) {
  103. warpSum[warp] = sum;
  104. }
  105. __syncthreads();
  106. if (warp == 0) {
  107. float v = (lane < 8) ? warpSum[lane] : 0.0f;
  108. for (int offset = 16; offset > 0; offset >>= 1) {
  109. v += __shfl_down_sync(0xffffffff, v, offset);
  110. }
  111. if (lane == 0) {
  112. rms = rsqrtf(v / dim + eps);
  113. }
  114. }
  115. __syncthreads();
  116. for (int i = threadIdx.x; i < dim; i += blockDim.x) {
  117. rowData[i] = rowData[i] * rms * w[i];
  118. }
  119. }
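// Paged attention over an f16 KV cache, batched decode path: one warp handles one
// (token, head) pair and streams KV blocks with an online softmax. m is the running
// max score, l the running sum of exp(score - m), and the accumulator is rescaled by
// alpha = exp(m_old - m_new) whenever the max moves, so no score matrix is materialized.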
  120. __global__ void paged_attention_batch_kernel_f16kv(
  121. const float* Q,
  122. const unsigned short* const* KBlocksFlat,
  123. const unsigned short* const* VBlocksFlat,
  124. const int* blockOffsets,
  125. const int* kvLens,
  126. const int* queryPos,
  127. float* out,
  128. int numTokens,
  129. int numHeads, int numKVHeads, int headDim,
  130. int blockSize,
  131. float scale
  132. ) {
  133. const int tok = blockIdx.x;
  134. const int head = blockIdx.y;
  135. const int lane = threadIdx.x & 31;
  136. if (tok >= numTokens) {
  137. return;
  138. }
  139. // One warp per (tok, head)
  140. if (threadIdx.x >= 32) {
  141. return;
  142. }
  143. const int kvHead = head / (numHeads / numKVHeads);
  144. const float* q = Q + tok * numHeads * headDim + head * headDim;
  145. float* o = out + tok * numHeads * headDim + head * headDim;
  146. const int kvLen = kvLens[tok];
  147. const int qPos = queryPos[tok];
  148. const int base = blockOffsets[tok];
  149. const int kvStride = numKVHeads * headDim;
  150. const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
  151. // Cache Q in registers (per lane) to avoid reloading it for every KV token.
  152. float qreg[8];
  153. #pragma unroll
  154. for (int i = 0; i < 8; i++) {
  155. const int d = lane + 32 * i;
  156. qreg[i] = (d < headDim) ? q[d] : 0.0f;
  157. }
  158. // Support headDim up to 256 (<= 8 values per lane)
  159. float acc[8];
  160. #pragma unroll
  161. for (int i = 0; i < 8; i++) {
  162. acc[i] = 0.0f;
  163. }
  164. float m = -INFINITY;
  165. float l = 0.0f;
  166. for (int kv = 0; kv < effectiveLen; kv++) {
  167. const int bidx = kv / blockSize;
  168. const int boff = kv % blockSize;
  169. const __half* kBlock = reinterpret_cast<const __half*>(KBlocksFlat[base + bidx]);
  170. const __half* k = kBlock + boff * kvStride + kvHead * headDim;
  171. float dot = 0.0f;
  172. #pragma unroll
  173. for (int i = 0; i < 8; i++) {
  174. const int d = lane + 32 * i;
  175. if (d < headDim) {
  176. dot = fmaf(qreg[i], __half2float(k[d]), dot);
  177. }
  178. }
  179. for (int offset = 16; offset > 0; offset >>= 1) {
  180. dot += __shfl_down_sync(0xffffffff, dot, offset);
  181. }
  182. dot = __shfl_sync(0xffffffff, dot, 0);
  183. const float score = dot * scale;
  184. float alpha = 1.0f;
  185. float beta = 0.0f;
  186. if (lane == 0) {
  187. const float newM = fmaxf(m, score);
  188. alpha = __expf(m - newM);
  189. beta = __expf(score - newM);
  190. m = newM;
  191. l = l * alpha + beta;
  192. }
  193. alpha = __shfl_sync(0xffffffff, alpha, 0);
  194. beta = __shfl_sync(0xffffffff, beta, 0);
  195. l = __shfl_sync(0xffffffff, l, 0);
  196. const __half* vBlock = reinterpret_cast<const __half*>(VBlocksFlat[base + bidx]);
  197. const __half* v = vBlock + boff * kvStride + kvHead * headDim;
  198. #pragma unroll
  199. for (int i = 0; i < 8; i++) {
  200. const int d = lane + 32 * i;
  201. if (d < headDim) {
  202. acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
  203. }
  204. }
  205. }
  206. const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
  207. #pragma unroll
  208. for (int i = 0; i < 8; i++) {
  209. const int d = lane + 32 * i;
  210. if (d < headDim) {
  211. o[d] = acc[i] * invL;
  212. }
  213. }
  214. }
  215. // Fused RoPE + paged attention (batch, f16 KV). Expects un-rotated Q/K.
  216. __global__ void paged_attention_batch_kernel_f16kv_rope(
  217. const float* Q,
  218. const unsigned short* const* KBlocksFlat,
  219. const unsigned short* const* VBlocksFlat,
  220. const int* blockOffsets,
  221. const int* kvLens,
  222. const int* queryPos,
  223. float* out,
  224. int numTokens,
  225. int numHeads, int numKVHeads, int headDim,
  226. int blockSize,
  227. float scale
  228. ) {
  229. const int tok = blockIdx.x;
  230. const int head = blockIdx.y;
  231. const int lane = threadIdx.x & 31;
  232. if (tok >= numTokens) {
  233. return;
  234. }
  235. // One warp per (tok, head)
  236. if (threadIdx.x >= 32) {
  237. return;
  238. }
  239. const int kvHead = head / (numHeads / numKVHeads);
  240. const float* q = Q + tok * numHeads * headDim + head * headDim;
  241. float* o = out + tok * numHeads * headDim + head * headDim;
  242. const int kvLen = kvLens[tok];
  243. const int qPos = queryPos[tok];
  244. const int base = blockOffsets[tok];
  245. const int kvStride = numKVHeads * headDim;
  246. const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
  247. const int halfDim = headDim >> 1;
  248. // Cache Q pairs + per-dim RoPE phase for delta=(qPos - kv) with kv starting at 0.
  249. float q0[4];
  250. float q1[4];
  251. float cosStep[4];
  252. float sinStep[4];
  253. float cosDelta[4];
  254. float sinDelta[4];
  255. int pairCount = 0;
  256. for (int j = lane; j < halfDim; j += 32) {
  257. q0[pairCount] = q[j];
  258. q1[pairCount] = q[j + halfDim];
  259. cosStep[pairCount] = rope_cos_step_const[j];
  260. sinStep[pairCount] = rope_sin_step_const[j];
  261. const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
  262. const float2 ph = complex_pow_int(baseStep, qPos);
  263. cosDelta[pairCount] = ph.x;
  264. sinDelta[pairCount] = ph.y;
  265. pairCount++;
  266. }
  267. // Support headDim up to 256.
  268. float acc[8];
  269. #pragma unroll
  270. for (int i = 0; i < 8; i++) {
  271. acc[i] = 0.0f;
  272. }
  273. float m = -INFINITY;
  274. float l = 0.0f;
  275. for (int kv = 0; kv < effectiveLen; kv++) {
  276. const int bidx = kv / blockSize;
  277. const int boff = kv % blockSize;
  278. const __half* kBlock = reinterpret_cast<const __half*>(KBlocksFlat[base + bidx]);
  279. const __half* k = kBlock + boff * kvStride + kvHead * headDim;
  280. float dot = 0.0f;
  281. #pragma unroll
  282. for (int pi = 0; pi < 4; pi++) {
  283. if (pi >= pairCount) {
  284. break;
  285. }
  286. const int j = lane + 32 * pi;
  287. if (j >= halfDim) {
  288. continue;
  289. }
  290. const float k0 = __half2float(k[j]);
  291. const float k1 = __half2float(k[j + halfDim]);
  292. const float a = fmaf(q0[pi], k0, q1[pi] * k1); // q0*k0 + q1*k1
  293. const float b = fmaf(q0[pi], k1, -q1[pi] * k0); // q0*k1 - q1*k0
  294. dot = fmaf(cosDelta[pi], a, dot);
  295. dot = fmaf(sinDelta[pi], b, dot);
  296. }
  297. for (int offset = 16; offset > 0; offset >>= 1) {
  298. dot += __shfl_down_sync(0xffffffff, dot, offset);
  299. }
  300. dot = __shfl_sync(0xffffffff, dot, 0);
  301. // Advance delta -> delta-1 for next kv.
  302. #pragma unroll
  303. for (int pi = 0; pi < 4; pi++) {
  304. if (pi >= pairCount) {
  305. break;
  306. }
  307. rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
  308. }
  309. const float score = dot * scale;
  310. float alpha = 1.0f;
  311. float beta = 0.0f;
  312. if (lane == 0) {
  313. const float newM = fmaxf(m, score);
  314. alpha = __expf(m - newM);
  315. beta = __expf(score - newM);
  316. m = newM;
  317. l = l * alpha + beta;
  318. }
  319. alpha = __shfl_sync(0xffffffff, alpha, 0);
  320. beta = __shfl_sync(0xffffffff, beta, 0);
  321. l = __shfl_sync(0xffffffff, l, 0);
  322. const __half* vBlock = reinterpret_cast<const __half*>(VBlocksFlat[base + bidx]);
  323. const __half* v = vBlock + boff * kvStride + kvHead * headDim;
  324. #pragma unroll
  325. for (int i = 0; i < 8; i++) {
  326. const int d = lane + 32 * i;
  327. if (d < headDim) {
  328. acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
  329. }
  330. }
  331. }
  332. const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
  333. #pragma unroll
  334. for (int i = 0; i < 8; i++) {
  335. const int d = lane + 32 * i;
  336. if (d < headDim) {
  337. o[d] = acc[i] * invL;
  338. }
  339. }
  340. }
  341. __global__ void cast_f32_to_f16_kernel(const float* src, __half* dst, int n) {
  342. int i = blockIdx.x * blockDim.x + threadIdx.x;
  343. if (i < n) {
  344. dst[i] = __float2half_rn(src[i]);
  345. }
  346. }
  347. __global__ void paged_attention_kernel_f16kv(
  348. const float* Q,
  349. const unsigned short* const* KBlocks,
  350. const unsigned short* const* VBlocks,
  351. float* out,
  352. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  353. int blockSize,
  354. float scale, int startPos
  355. ) {
  356. const int seq = blockIdx.x;
  357. const int head = blockIdx.y;
  358. const int lane = threadIdx.x & 31;
  359. if (seq >= seqLen) {
  360. return;
  361. }
  362. // One warp per (seq, head)
  363. if (threadIdx.x >= 32) {
  364. return;
  365. }
  366. const int kvHead = head / (numHeads / numKVHeads);
  367. const float* q = Q + seq * numHeads * headDim + head * headDim;
  368. float* o = out + seq * numHeads * headDim + head * headDim;
  369. const int kvStride = numKVHeads * headDim;
  370. const int queryPos = startPos + seq;
  371. const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
  372. // Cache Q in registers (per lane) to avoid reloading it for every KV token.
  373. float qreg[8];
  374. #pragma unroll
  375. for (int i = 0; i < 8; i++) {
  376. const int d = lane + 32 * i;
  377. qreg[i] = (d < headDim) ? q[d] : 0.0f;
  378. }
  379. float acc[8];
  380. #pragma unroll
  381. for (int i = 0; i < 8; i++) {
  382. acc[i] = 0.0f;
  383. }
  384. float m = -INFINITY;
  385. float l = 0.0f;
  386. for (int kv = 0; kv < effectiveLen; kv++) {
  387. const int blockIdxKV = kv / blockSize;
  388. const int blockOff = kv % blockSize;
  389. const __half* kBlock = reinterpret_cast<const __half*>(KBlocks[blockIdxKV]);
  390. const __half* k = kBlock + blockOff * kvStride + kvHead * headDim;
  391. float dot = 0.0f;
  392. #pragma unroll
  393. for (int i = 0; i < 8; i++) {
  394. const int d = lane + 32 * i;
  395. if (d < headDim) {
  396. dot = fmaf(qreg[i], __half2float(k[d]), dot);
  397. }
  398. }
  399. for (int offset = 16; offset > 0; offset >>= 1) {
  400. dot += __shfl_down_sync(0xffffffff, dot, offset);
  401. }
  402. dot = __shfl_sync(0xffffffff, dot, 0);
  403. const float score = dot * scale;
  404. float alpha = 1.0f;
  405. float beta = 0.0f;
  406. if (lane == 0) {
  407. const float newM = fmaxf(m, score);
  408. alpha = __expf(m - newM);
  409. beta = __expf(score - newM);
  410. m = newM;
  411. l = l * alpha + beta;
  412. }
  413. alpha = __shfl_sync(0xffffffff, alpha, 0);
  414. beta = __shfl_sync(0xffffffff, beta, 0);
  415. l = __shfl_sync(0xffffffff, l, 0);
  416. const __half* vBlock = reinterpret_cast<const __half*>(VBlocks[blockIdxKV]);
  417. const __half* v = vBlock + blockOff * kvStride + kvHead * headDim;
  418. #pragma unroll
  419. for (int i = 0; i < 8; i++) {
  420. const int d = lane + 32 * i;
  421. if (d < headDim) {
  422. acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
  423. }
  424. }
  425. }
  426. const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
  427. #pragma unroll
  428. for (int i = 0; i < 8; i++) {
  429. const int d = lane + 32 * i;
  430. if (d < headDim) {
  431. o[d] = acc[i] * invL;
  432. }
  433. }
  434. }
  435. // Fused RoPE + paged attention (single, f16 KV). Expects un-rotated Q/K.
  436. __global__ void paged_attention_kernel_f16kv_rope(
  437. const float* Q,
  438. const unsigned short* const* KBlocks,
  439. const unsigned short* const* VBlocks,
  440. float* out,
  441. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  442. int blockSize,
  443. float scale, int startPos
  444. ) {
  445. const int seq = blockIdx.x;
  446. const int head = blockIdx.y;
  447. const int lane = threadIdx.x & 31;
  448. if (seq >= seqLen) {
  449. return;
  450. }
  451. // One warp per (seq, head)
  452. if (threadIdx.x >= 32) {
  453. return;
  454. }
  455. const int kvHead = head / (numHeads / numKVHeads);
  456. const float* q = Q + seq * numHeads * headDim + head * headDim;
  457. float* o = out + seq * numHeads * headDim + head * headDim;
  458. const int kvStride = numKVHeads * headDim;
  459. const int queryPos = startPos + seq;
  460. const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
  461. const int halfDim = headDim >> 1;
  462. float q0[4];
  463. float q1[4];
  464. float cosStep[4];
  465. float sinStep[4];
  466. float cosDelta[4];
  467. float sinDelta[4];
  468. int pairCount = 0;
  469. for (int j = lane; j < halfDim; j += 32) {
  470. q0[pairCount] = q[j];
  471. q1[pairCount] = q[j + halfDim];
  472. cosStep[pairCount] = rope_cos_step_const[j];
  473. sinStep[pairCount] = rope_sin_step_const[j];
  474. const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
  475. const float2 ph = complex_pow_int(baseStep, queryPos);
  476. cosDelta[pairCount] = ph.x;
  477. sinDelta[pairCount] = ph.y;
  478. pairCount++;
  479. }
  480. float acc[8];
  481. #pragma unroll
  482. for (int i = 0; i < 8; i++) {
  483. acc[i] = 0.0f;
  484. }
  485. float m = -INFINITY;
  486. float l = 0.0f;
  487. for (int kv = 0; kv < effectiveLen; kv++) {
  488. const int blockIdxKV = kv / blockSize;
  489. const int blockOff = kv % blockSize;
  490. const __half* kBlock = reinterpret_cast<const __half*>(KBlocks[blockIdxKV]);
  491. const __half* k = kBlock + blockOff * kvStride + kvHead * headDim;
  492. float dot = 0.0f;
  493. #pragma unroll
  494. for (int pi = 0; pi < 4; pi++) {
  495. if (pi >= pairCount) {
  496. break;
  497. }
  498. const int j = lane + 32 * pi;
  499. if (j >= halfDim) {
  500. continue;
  501. }
  502. const float k0 = __half2float(k[j]);
  503. const float k1 = __half2float(k[j + halfDim]);
  504. const float a = fmaf(q0[pi], k0, q1[pi] * k1);
  505. const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
  506. dot = fmaf(cosDelta[pi], a, dot);
  507. dot = fmaf(sinDelta[pi], b, dot);
  508. }
  509. for (int offset = 16; offset > 0; offset >>= 1) {
  510. dot += __shfl_down_sync(0xffffffff, dot, offset);
  511. }
  512. dot = __shfl_sync(0xffffffff, dot, 0);
  513. // Advance delta -> delta-1 for next kv.
  514. #pragma unroll
  515. for (int pi = 0; pi < 4; pi++) {
  516. if (pi >= pairCount) {
  517. break;
  518. }
  519. rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
  520. }
  521. const float score = dot * scale;
  522. float alpha = 1.0f;
  523. float beta = 0.0f;
  524. if (lane == 0) {
  525. const float newM = fmaxf(m, score);
  526. alpha = __expf(m - newM);
  527. beta = __expf(score - newM);
  528. m = newM;
  529. l = l * alpha + beta;
  530. }
  531. alpha = __shfl_sync(0xffffffff, alpha, 0);
  532. beta = __shfl_sync(0xffffffff, beta, 0);
  533. l = __shfl_sync(0xffffffff, l, 0);
  534. const __half* vBlock = reinterpret_cast<const __half*>(VBlocks[blockIdxKV]);
  535. const __half* v = vBlock + blockOff * kvStride + kvHead * headDim;
  536. #pragma unroll
  537. for (int i = 0; i < 8; i++) {
  538. const int d = lane + 32 * i;
  539. if (d < headDim) {
  540. acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
  541. }
  542. }
  543. }
  544. const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
  545. #pragma unroll
  546. for (int i = 0; i < 8; i++) {
  547. const int d = lane + 32 * i;
  548. if (d < headDim) {
  549. o[d] = acc[i] * invL;
  550. }
  551. }
  552. }
  553. template <typename T>
  554. __device__ __forceinline__ float load_kv(const T* p, int idx) {
  555. return p[idx];
  556. }
  557. template <>
  558. __device__ __forceinline__ float load_kv<__half>(const __half* p, int idx) {
  559. return __half2float(p[idx]);
  560. }
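// load_kv lets the split-KV kernels below be instantiated for float and __half KV
// caches; the __half specialization converts on load.
// Split-KV ("flash decoding"): blockIdx.z selects a contiguous slice of the KV range.
// Each warp emits a partial softmax state (running max, running sum) plus an
// unnormalized output vector for its slice; paged_attention_split_kv_reduce_kernel
// merges the slices afterwards.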
  561. template <typename T, bool kUseRoPE = false>
  562. __global__ void paged_attention_split_kv_kernel(
  563. const float* Q,
  564. const T* const* KBlocks,
  565. const T* const* VBlocks,
  566. float* partialMax,
  567. float* partialSum,
  568. float* partialOut,
  569. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  570. int blockSize,
  571. float scale, int startPos,
  572. int numSplits, int splitSize
  573. ) {
  574. const int seq = blockIdx.x;
  575. const int head = blockIdx.y;
  576. const int split = blockIdx.z;
  577. const int lane = threadIdx.x & 31;
  578. if (seq >= seqLen) {
  579. return;
  580. }
  581. // One warp per (seq, head, split).
  582. if (threadIdx.x >= 32) {
  583. return;
  584. }
  585. const int kvHead = head / (numHeads / numKVHeads);
  586. const float* q = Q + seq * numHeads * headDim + head * headDim;
  587. const int kvStride = numKVHeads * headDim;
  588. const int queryPos = startPos + seq;
  589. const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
  590. const int splitStart = split * splitSize;
  591. const int splitEnd = (splitStart + splitSize < effectiveLen) ? (splitStart + splitSize) : effectiveLen;
  592. const size_t qh = (size_t)seq * (size_t)numHeads + (size_t)head;
  593. const size_t splitIdx = qh * (size_t)numSplits + (size_t)split;
  594. float* outVec = partialOut + splitIdx * (size_t)headDim;
  595. if (splitStart >= splitEnd) {
  596. if (lane == 0) {
  597. partialMax[splitIdx] = -INFINITY;
  598. partialSum[splitIdx] = 0.0f;
  599. }
  600. #pragma unroll
  601. for (int i = 0; i < 8; i++) {
  602. const int d = lane + 32 * i;
  603. if (d < headDim) {
  604. outVec[d] = 0.0f;
  605. }
  606. }
  607. return;
  608. }
  609. const int halfDim = headDim >> 1;
  610. const int ropeExp = queryPos - splitStart;
  611. float qreg[8];
  612. float q0[4];
  613. float q1[4];
  614. float cosStep[4];
  615. float sinStep[4];
  616. float cosDelta[4];
  617. float sinDelta[4];
  618. int pairCount = 0;
  619. if (kUseRoPE) {
  620. for (int j = lane; j < halfDim; j += 32) {
  621. q0[pairCount] = q[j];
  622. q1[pairCount] = q[j + halfDim];
  623. cosStep[pairCount] = rope_cos_step_const[j];
  624. sinStep[pairCount] = rope_sin_step_const[j];
  625. const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
  626. const float2 ph = complex_pow_int(baseStep, ropeExp);
  627. cosDelta[pairCount] = ph.x;
  628. sinDelta[pairCount] = ph.y;
  629. pairCount++;
  630. }
  631. } else {
  632. #pragma unroll
  633. for (int i = 0; i < 8; i++) {
  634. const int d = lane + 32 * i;
  635. qreg[i] = (d < headDim) ? q[d] : 0.0f;
  636. }
  637. }
  638. float acc[8];
  639. #pragma unroll
  640. for (int i = 0; i < 8; i++) {
  641. acc[i] = 0.0f;
  642. }
  643. float m = -INFINITY;
  644. float l = 0.0f;
  645. for (int kv = splitStart; kv < splitEnd; kv++) {
  646. const int blockIdxKV = kv / blockSize;
  647. const int blockOff = kv % blockSize;
  648. const T* kBlock = KBlocks[blockIdxKV];
  649. const T* k = kBlock + blockOff * kvStride + kvHead * headDim;
  650. float dot = 0.0f;
  651. if (kUseRoPE) {
  652. #pragma unroll
  653. for (int pi = 0; pi < 4; pi++) {
  654. if (pi >= pairCount) {
  655. break;
  656. }
  657. const int j = lane + 32 * pi;
  658. if (j >= halfDim) {
  659. continue;
  660. }
  661. const float k0 = load_kv<T>(k, j);
  662. const float k1 = load_kv<T>(k, j + halfDim);
  663. const float a = fmaf(q0[pi], k0, q1[pi] * k1);
  664. const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
  665. dot = fmaf(cosDelta[pi], a, dot);
  666. dot = fmaf(sinDelta[pi], b, dot);
  667. }
  668. } else {
  669. #pragma unroll
  670. for (int i = 0; i < 8; i++) {
  671. const int d = lane + 32 * i;
  672. if (d < headDim) {
  673. dot = fmaf(qreg[i], load_kv<T>(k, d), dot);
  674. }
  675. }
  676. }
  677. for (int offset = 16; offset > 0; offset >>= 1) {
  678. dot += __shfl_down_sync(0xffffffff, dot, offset);
  679. }
  680. dot = __shfl_sync(0xffffffff, dot, 0);
  681. if (kUseRoPE) {
  682. #pragma unroll
  683. for (int pi = 0; pi < 4; pi++) {
  684. if (pi >= pairCount) {
  685. break;
  686. }
  687. rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
  688. }
  689. }
  690. const float score = dot * scale;
  691. float alpha = 1.0f;
  692. float beta = 0.0f;
  693. if (lane == 0) {
  694. const float newM = fmaxf(m, score);
  695. alpha = __expf(m - newM);
  696. beta = __expf(score - newM);
  697. m = newM;
  698. l = l * alpha + beta;
  699. }
  700. alpha = __shfl_sync(0xffffffff, alpha, 0);
  701. beta = __shfl_sync(0xffffffff, beta, 0);
  702. const T* vBlock = VBlocks[blockIdxKV];
  703. const T* v = vBlock + blockOff * kvStride + kvHead * headDim;
  704. #pragma unroll
  705. for (int i = 0; i < 8; i++) {
  706. const int d = lane + 32 * i;
  707. if (d < headDim) {
  708. acc[i] = acc[i] * alpha + beta * load_kv<T>(v, d);
  709. }
  710. }
  711. }
  712. if (lane == 0) {
  713. partialMax[splitIdx] = m;
  714. partialSum[splitIdx] = l;
  715. }
  716. #pragma unroll
  717. for (int i = 0; i < 8; i++) {
  718. const int d = lane + 32 * i;
  719. if (d < headDim) {
  720. outVec[d] = acc[i];
  721. }
  722. }
  723. }
  724. template <typename T, bool kUseRoPE = false>
  725. __global__ void paged_attention_split_kv_batch_kernel(
  726. const float* Q,
  727. const T* const* KBlocksFlat,
  728. const T* const* VBlocksFlat,
  729. const int* blockOffsets,
  730. const int* kvLens,
  731. const int* queryPos,
  732. float* partialMax,
  733. float* partialSum,
  734. float* partialOut,
  735. int numTokens,
  736. int numHeads, int numKVHeads, int headDim,
  737. int blockSize,
  738. float scale,
  739. int numSplits, int splitSize
  740. ) {
  741. const int tok = blockIdx.x;
  742. const int head = blockIdx.y;
  743. const int split = blockIdx.z;
  744. const int lane = threadIdx.x & 31;
  745. if (tok >= numTokens) {
  746. return;
  747. }
  748. // One warp per (tok, head, split).
  749. if (threadIdx.x >= 32) {
  750. return;
  751. }
  752. const int kvHead = head / (numHeads / numKVHeads);
  753. const float* q = Q + tok * numHeads * headDim + head * headDim;
  754. const int kvLen = kvLens[tok];
  755. const int qPos = queryPos[tok];
  756. const int base = blockOffsets[tok];
  757. const int kvStride = numKVHeads * headDim;
  758. const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
  759. const int splitStart = split * splitSize;
  760. const int splitEnd = (splitStart + splitSize < effectiveLen) ? (splitStart + splitSize) : effectiveLen;
  761. const size_t qh = (size_t)tok * (size_t)numHeads + (size_t)head;
  762. const size_t splitIdx = qh * (size_t)numSplits + (size_t)split;
  763. float* outVec = partialOut + splitIdx * (size_t)headDim;
  764. if (splitStart >= splitEnd) {
  765. if (lane == 0) {
  766. partialMax[splitIdx] = -INFINITY;
  767. partialSum[splitIdx] = 0.0f;
  768. }
  769. #pragma unroll
  770. for (int i = 0; i < 8; i++) {
  771. const int d = lane + 32 * i;
  772. if (d < headDim) {
  773. outVec[d] = 0.0f;
  774. }
  775. }
  776. return;
  777. }
  778. const int halfDim = headDim >> 1;
  779. const int ropeExp = qPos - splitStart;
  780. float qreg[8];
  781. float q0[4];
  782. float q1[4];
  783. float cosStep[4];
  784. float sinStep[4];
  785. float cosDelta[4];
  786. float sinDelta[4];
  787. int pairCount = 0;
  788. if (kUseRoPE) {
  789. for (int j = lane; j < halfDim; j += 32) {
  790. q0[pairCount] = q[j];
  791. q1[pairCount] = q[j + halfDim];
  792. cosStep[pairCount] = rope_cos_step_const[j];
  793. sinStep[pairCount] = rope_sin_step_const[j];
  794. const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
  795. const float2 ph = complex_pow_int(baseStep, ropeExp);
  796. cosDelta[pairCount] = ph.x;
  797. sinDelta[pairCount] = ph.y;
  798. pairCount++;
  799. }
  800. } else {
  801. #pragma unroll
  802. for (int i = 0; i < 8; i++) {
  803. const int d = lane + 32 * i;
  804. qreg[i] = (d < headDim) ? q[d] : 0.0f;
  805. }
  806. }
  807. float acc[8];
  808. #pragma unroll
  809. for (int i = 0; i < 8; i++) {
  810. acc[i] = 0.0f;
  811. }
  812. float m = -INFINITY;
  813. float l = 0.0f;
  814. for (int kv = splitStart; kv < splitEnd; kv++) {
  815. const int bidx = kv / blockSize;
  816. const int boff = kv % blockSize;
  817. const T* kBlock = KBlocksFlat[base + bidx];
  818. const T* k = kBlock + boff * kvStride + kvHead * headDim;
  819. float dot = 0.0f;
  820. if (kUseRoPE) {
  821. #pragma unroll
  822. for (int pi = 0; pi < 4; pi++) {
  823. if (pi >= pairCount) {
  824. break;
  825. }
  826. const int j = lane + 32 * pi;
  827. if (j >= halfDim) {
  828. continue;
  829. }
  830. const float k0 = load_kv<T>(k, j);
  831. const float k1 = load_kv<T>(k, j + halfDim);
  832. const float a = fmaf(q0[pi], k0, q1[pi] * k1);
  833. const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
  834. dot = fmaf(cosDelta[pi], a, dot);
  835. dot = fmaf(sinDelta[pi], b, dot);
  836. }
  837. } else {
  838. #pragma unroll
  839. for (int i = 0; i < 8; i++) {
  840. const int d = lane + 32 * i;
  841. if (d < headDim) {
  842. dot = fmaf(qreg[i], load_kv<T>(k, d), dot);
  843. }
  844. }
  845. }
  846. for (int offset = 16; offset > 0; offset >>= 1) {
  847. dot += __shfl_down_sync(0xffffffff, dot, offset);
  848. }
  849. dot = __shfl_sync(0xffffffff, dot, 0);
  850. if (kUseRoPE) {
  851. #pragma unroll
  852. for (int pi = 0; pi < 4; pi++) {
  853. if (pi >= pairCount) {
  854. break;
  855. }
  856. rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
  857. }
  858. }
  859. const float score = dot * scale;
  860. float alpha = 1.0f;
  861. float beta = 0.0f;
  862. if (lane == 0) {
  863. const float newM = fmaxf(m, score);
  864. alpha = __expf(m - newM);
  865. beta = __expf(score - newM);
  866. m = newM;
  867. l = l * alpha + beta;
  868. }
  869. alpha = __shfl_sync(0xffffffff, alpha, 0);
  870. beta = __shfl_sync(0xffffffff, beta, 0);
  871. const T* vBlock = VBlocksFlat[base + bidx];
  872. const T* v = vBlock + boff * kvStride + kvHead * headDim;
  873. #pragma unroll
  874. for (int i = 0; i < 8; i++) {
  875. const int d = lane + 32 * i;
  876. if (d < headDim) {
  877. acc[i] = acc[i] * alpha + beta * load_kv<T>(v, d);
  878. }
  879. }
  880. }
  881. if (lane == 0) {
  882. partialMax[splitIdx] = m;
  883. partialSum[splitIdx] = l;
  884. }
  885. #pragma unroll
  886. for (int i = 0; i < 8; i++) {
  887. const int d = lane + 32 * i;
  888. if (d < headDim) {
  889. outVec[d] = acc[i];
  890. }
  891. }
  892. }
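// Merges per-split partials: with global max M = max_s m_s, the combined denominator
// is sum_s l_s * exp(m_s - M), and each partial output vector is rescaled by
// exp(m_s - M) before the final normalization.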
  893. __global__ void paged_attention_split_kv_reduce_kernel(
  894. const float* partialMax,
  895. const float* partialSum,
  896. const float* partialOut,
  897. float* out,
  898. int queryCount,
  899. int numHeads,
  900. int headDim,
  901. int numSplits
  902. ) {
  903. const int q = blockIdx.x;
  904. const int head = blockIdx.y;
  905. const int lane = threadIdx.x & 31;
  906. if (q >= queryCount) {
  907. return;
  908. }
  909. // One warp per (q, head).
  910. if (threadIdx.x >= 32) {
  911. return;
  912. }
  913. const size_t qh = (size_t)q * (size_t)numHeads + (size_t)head;
  914. const size_t base = qh * (size_t)numSplits;
  915. float gmax = -INFINITY;
  916. if (lane == 0) {
  917. for (int s = 0; s < numSplits; s++) {
  918. gmax = fmaxf(gmax, partialMax[base + (size_t)s]);
  919. }
  920. }
  921. gmax = __shfl_sync(0xffffffff, gmax, 0);
  922. if (gmax == -INFINITY) {
  923. const size_t outBase = qh * (size_t)headDim;
  924. #pragma unroll
  925. for (int i = 0; i < 8; i++) {
  926. const int d = lane + 32 * i;
  927. if (d < headDim) {
  928. out[outBase + (size_t)d] = 0.0f;
  929. }
  930. }
  931. return;
  932. }
  933. float sum = 0.0f;
  934. if (lane == 0) {
  935. for (int s = 0; s < numSplits; s++) {
  936. const float l = partialSum[base + (size_t)s];
  937. sum += l * __expf(partialMax[base + (size_t)s] - gmax);
  938. }
  939. }
  940. sum = __shfl_sync(0xffffffff, sum, 0);
  941. const float invSum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
  942. float acc[8];
  943. #pragma unroll
  944. for (int i = 0; i < 8; i++) {
  945. acc[i] = 0.0f;
  946. }
  947. for (int s = 0; s < numSplits; s++) {
  948. float scale = 0.0f;
  949. if (lane == 0) {
  950. scale = __expf(partialMax[base + (size_t)s] - gmax);
  951. }
  952. scale = __shfl_sync(0xffffffff, scale, 0);
  953. const float* vec = partialOut + (base + (size_t)s) * (size_t)headDim;
  954. #pragma unroll
  955. for (int i = 0; i < 8; i++) {
  956. const int d = lane + 32 * i;
  957. if (d < headDim) {
  958. acc[i] = fmaf(scale, vec[d], acc[i]);
  959. }
  960. }
  961. }
  962. const size_t outBase = qh * (size_t)headDim;
  963. #pragma unroll
  964. for (int i = 0; i < 8; i++) {
  965. const int d = lane + 32 * i;
  966. if (d < headDim) {
  967. out[outBase + (size_t)d] = acc[i] * invSum;
  968. }
  969. }
  970. }
  971. int cuda_rmsnorm_f32(float* x, const float* w, int seqLen, int dim, float eps) {
  972. int threads = 256;
  973. rmsnorm_kernel<<<seqLen, threads>>>(x, w, dim, eps);
  974. CHECK_CUDA(cudaGetLastError());
  975. return 0;
  976. }
  977. int cuda_cast_f32_to_f16(const float* src, unsigned short* dst, int n) {
  978. int threads = 256;
  979. int blocks = (n + threads - 1) / threads;
  980. cast_f32_to_f16_kernel<<<blocks, threads>>>(src, reinterpret_cast<__half*>(dst), n);
  981. CHECK_CUDA(cudaGetLastError());
  982. return 0;
  983. }
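// The split-KV wrappers below allocate one scratch buffer per call and carve it up as
// [partialMax | partialSum | partialOut] = splitCount + splitCount + splitCount*headDim
// floats, where splitCount = queries * numHeads * numSplits.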
  984. int cuda_paged_attention_f32_f16kv(
  985. const float* Q,
  986. const unsigned short* const* KBlocks,
  987. const unsigned short* const* VBlocks,
  988. float* out,
  989. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  990. int blockSize,
  991. float scale, int startPos
  992. ) {
  993. // Split-KV Flash Decoding for long contexts.
  994. const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
  995. const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
  996. const int qhCount = seqLen * numHeads;
  997. const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
  998. if (!useSplit) {
  999. dim3 blocks(seqLen, numHeads);
  1000. int threads = 32;
  1001. paged_attention_kernel_f16kv<<<blocks, threads>>>(
  1002. Q, KBlocks, VBlocks, out,
  1003. seqLen, kvLen, numHeads, numKVHeads, headDim,
  1004. blockSize, scale, startPos
  1005. );
  1006. CHECK_CUDA(cudaGetLastError());
  1007. return 0;
  1008. }
  1009. const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
  1010. const size_t totalFloats = splitCount * (size_t)(headDim + 2);
  1011. float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
  1012. if (buf == nullptr) {
  1013. return 1;
  1014. }
  1015. float* partialMax = buf;
  1016. float* partialSum = partialMax + splitCount;
  1017. float* partialOut = partialSum + splitCount;
  1018. dim3 blocks1(seqLen, numHeads, numSplits);
  1019. paged_attention_split_kv_kernel<__half><<<blocks1, 32>>>(
  1020. Q,
  1021. reinterpret_cast<const __half* const*>(KBlocks),
  1022. reinterpret_cast<const __half* const*>(VBlocks),
  1023. partialMax, partialSum, partialOut,
  1024. seqLen, kvLen, numHeads, numKVHeads, headDim,
  1025. blockSize,
  1026. scale, startPos,
  1027. numSplits, kPagedAttentionSplitSize
  1028. );
  1029. CHECK_CUDA(cudaGetLastError());
  1030. dim3 blocks2(seqLen, numHeads);
  1031. paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
  1032. partialMax, partialSum, partialOut,
  1033. out,
  1034. seqLen, numHeads, headDim, numSplits
  1035. );
  1036. CHECK_CUDA(cudaGetLastError());
  1037. cuda_free(buf);
  1038. return 0;
  1039. }
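// Example call (hypothetical device pointers dQ/dKBlocks/dVBlocks/dOut): single-token
// decode over a 4096-token f16 KV cache with 32 query heads, 8 KV heads, headDim 128:
//   cuda_paged_attention_f32_f16kv(dQ, dKBlocks, dVBlocks, dOut,
//                                  /*seqLen=*/1, /*kvLen=*/4096, 32, 8, 128,
//                                  /*blockSize=*/16, 1.0f / sqrtf(128.0f),
//                                  /*startPos=*/4095);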
  1040. int cuda_paged_attention_rope_f32_f16kv(
  1041. const float* Q,
  1042. const unsigned short* const* KBlocks,
  1043. const unsigned short* const* VBlocks,
  1044. float* out,
  1045. int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
  1046. int blockSize,
  1047. float scale, int startPos,
  1048. float theta
  1049. ) {
  1050. if (ensure_rope_step_table(headDim, theta) != 0) {
  1051. return 1;
  1052. }
  1053. // Split-KV Flash Decoding for long contexts.
  1054. const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
  1055. const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
  1056. const int qhCount = seqLen * numHeads;
  1057. const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
  1058. if (!useSplit) {
  1059. dim3 blocks(seqLen, numHeads);
  1060. int threads = 32;
  1061. paged_attention_kernel_f16kv_rope<<<blocks, threads>>>(
  1062. Q, KBlocks, VBlocks, out,
  1063. seqLen, kvLen, numHeads, numKVHeads, headDim,
  1064. blockSize, scale, startPos
  1065. );
  1066. CHECK_CUDA(cudaGetLastError());
  1067. return 0;
  1068. }
  1069. const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
  1070. const size_t totalFloats = splitCount * (size_t)(headDim + 2);
  1071. float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
  1072. if (buf == nullptr) {
  1073. return 1;
  1074. }
  1075. float* partialMax = buf;
  1076. float* partialSum = partialMax + splitCount;
  1077. float* partialOut = partialSum + splitCount;
  1078. dim3 blocks1(seqLen, numHeads, numSplits);
  1079. paged_attention_split_kv_kernel<__half, true><<<blocks1, 32>>>(
  1080. Q,
  1081. reinterpret_cast<const __half* const*>(KBlocks),
  1082. reinterpret_cast<const __half* const*>(VBlocks),
  1083. partialMax, partialSum, partialOut,
  1084. seqLen, kvLen, numHeads, numKVHeads, headDim,
  1085. blockSize,
  1086. scale, startPos,
  1087. numSplits, kPagedAttentionSplitSize
  1088. );
  1089. CHECK_CUDA(cudaGetLastError());
  1090. dim3 blocks2(seqLen, numHeads);
  1091. paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
  1092. partialMax, partialSum, partialOut,
  1093. out,
  1094. seqLen, numHeads, headDim, numSplits
  1095. );
  1096. CHECK_CUDA(cudaGetLastError());
  1097. cuda_free(buf);
  1098. return 0;
  1099. }
  1100. int cuda_paged_attention_batch_f32_f16kv(
  1101. const float* Q,
  1102. const unsigned short* const* KBlocksFlat,
  1103. const unsigned short* const* VBlocksFlat,
  1104. const int* blockOffsets,
  1105. const int* kvLens,
  1106. const int* queryPos,
  1107. float* out,
  1108. int numTokens,
  1109. int numHeads, int numKVHeads, int headDim,
  1110. int blockSize,
  1111. float scale,
  1112. int maxKvLen
  1113. ) {
  1114. const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
  1115. const int qhCount = numTokens * numHeads;
  1116. const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
  1117. if (!useSplit) {
  1118. dim3 blocks(numTokens, numHeads);
  1119. int threads = 32;
  1120. paged_attention_batch_kernel_f16kv<<<blocks, threads>>>(
  1121. Q, KBlocksFlat, VBlocksFlat,
  1122. blockOffsets, kvLens, queryPos,
  1123. out,
  1124. numTokens,
  1125. numHeads, numKVHeads, headDim,
  1126. blockSize,
  1127. scale
  1128. );
  1129. CHECK_CUDA(cudaGetLastError());
  1130. return 0;
  1131. }
  1132. const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
  1133. const size_t totalFloats = splitCount * (size_t)(headDim + 2);
  1134. float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
  1135. if (buf == nullptr) {
  1136. return 1;
  1137. }
  1138. float* partialMax = buf;
  1139. float* partialSum = partialMax + splitCount;
  1140. float* partialOut = partialSum + splitCount;
  1141. dim3 blocks1(numTokens, numHeads, numSplits);
  1142. paged_attention_split_kv_batch_kernel<__half><<<blocks1, 32>>>(
  1143. Q,
  1144. reinterpret_cast<const __half* const*>(KBlocksFlat),
  1145. reinterpret_cast<const __half* const*>(VBlocksFlat),
  1146. blockOffsets, kvLens, queryPos,
  1147. partialMax, partialSum, partialOut,
  1148. numTokens,
  1149. numHeads, numKVHeads, headDim,
  1150. blockSize,
  1151. scale,
  1152. numSplits, kPagedAttentionSplitSize
  1153. );
  1154. CHECK_CUDA(cudaGetLastError());
  1155. dim3 blocks2(numTokens, numHeads);
  1156. paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
  1157. partialMax, partialSum, partialOut,
  1158. out,
  1159. numTokens, numHeads, headDim, numSplits
  1160. );
  1161. CHECK_CUDA(cudaGetLastError());
  1162. cuda_free(buf);
  1163. return 0;
  1164. }
  1165. int cuda_paged_attention_rope_batch_f32_f16kv(
  1166. const float* Q,
  1167. const unsigned short* const* KBlocksFlat,
  1168. const unsigned short* const* VBlocksFlat,
  1169. const int* blockOffsets,
  1170. const int* kvLens,
  1171. const int* queryPos,
  1172. float* out,
  1173. int numTokens,
  1174. int numHeads, int numKVHeads, int headDim,
  1175. int blockSize,
  1176. float scale,
  1177. int maxKvLen,
  1178. float theta
  1179. ) {
  1180. if (ensure_rope_step_table(headDim, theta) != 0) {
  1181. return 1;
  1182. }
  1183. const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
  1184. const int qhCount = numTokens * numHeads;
  1185. const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
  1186. if (!useSplit) {
  1187. dim3 blocks(numTokens, numHeads);
  1188. int threads = 32;
  1189. paged_attention_batch_kernel_f16kv_rope<<<blocks, threads>>>(
  1190. Q, KBlocksFlat, VBlocksFlat,
  1191. blockOffsets, kvLens, queryPos,
  1192. out,
  1193. numTokens,
  1194. numHeads, numKVHeads, headDim,
  1195. blockSize,
  1196. scale
  1197. );
  1198. CHECK_CUDA(cudaGetLastError());
  1199. return 0;
  1200. }
  1201. const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
  1202. const size_t totalFloats = splitCount * (size_t)(headDim + 2);
  1203. float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
  1204. if (buf == nullptr) {
  1205. return 1;
  1206. }
  1207. float* partialMax = buf;
  1208. float* partialSum = partialMax + splitCount;
  1209. float* partialOut = partialSum + splitCount;
  1210. dim3 blocks1(numTokens, numHeads, numSplits);
  1211. paged_attention_split_kv_batch_kernel<__half, true><<<blocks1, 32>>>(
  1212. Q,
  1213. reinterpret_cast<const __half* const*>(KBlocksFlat),
  1214. reinterpret_cast<const __half* const*>(VBlocksFlat),
  1215. blockOffsets, kvLens, queryPos,
  1216. partialMax, partialSum, partialOut,
  1217. numTokens,
  1218. numHeads, numKVHeads, headDim,
  1219. blockSize,
  1220. scale,
  1221. numSplits, kPagedAttentionSplitSize
  1222. );
  1223. CHECK_CUDA(cudaGetLastError());
  1224. dim3 blocks2(numTokens, numHeads);
  1225. paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
  1226. partialMax, partialSum, partialOut,
  1227. out,
  1228. numTokens, numHeads, headDim, numSplits
  1229. );
  1230. CHECK_CUDA(cudaGetLastError());
  1231. cuda_free(buf);
  1232. return 0;
  1233. }
  1234. __global__ void paged_attention_batch_kernel(
  1235. const float* Q,
  1236. const float* const* KBlocksFlat,
  1237. const float* const* VBlocksFlat,
  1238. const int* blockOffsets,
  1239. const int* kvLens,
  1240. const int* queryPos,
  1241. float* out,
  1242. int numTokens,
  1243. int numHeads, int numKVHeads, int headDim,
  1244. int blockSize,
  1245. float scale
  1246. );
  1247. __global__ void paged_attention_batch_kernel_f16kv(
  1248. const float* Q,
  1249. const unsigned short* const* KBlocksFlat,
  1250. const unsigned short* const* VBlocksFlat,
  1251. const int* blockOffsets,
  1252. const int* kvLens,
  1253. const int* queryPos,
  1254. float* out,
  1255. int numTokens,
  1256. int numHeads, int numKVHeads, int headDim,
  1257. int blockSize,
  1258. float scale
  1259. );
  1260. int cuda_paged_attention_batch_f32(
  1261. const float* Q,
  1262. const float* const* KBlocksFlat,
  1263. const float* const* VBlocksFlat,
  1264. const int* blockOffsets,
  1265. const int* kvLens,
  1266. const int* queryPos,
  1267. float* out,
  1268. int numTokens,
  1269. int numHeads, int numKVHeads, int headDim,
  1270. int blockSize,
  1271. float scale,
  1272. int maxKvLen
  1273. ) {
  1274. const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
  1275. const int qhCount = numTokens * numHeads;
  1276. const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
  1277. if (!useSplit) {
  1278. dim3 blocks(numTokens, numHeads);
  1279. int threads = 256;
  1280. paged_attention_batch_kernel<<<blocks, threads>>>(
  1281. Q, KBlocksFlat, VBlocksFlat,
  1282. blockOffsets, kvLens, queryPos,
  1283. out,
  1284. numTokens,
  1285. numHeads, numKVHeads, headDim,
  1286. blockSize,
  1287. scale
  1288. );
  1289. CHECK_CUDA(cudaGetLastError());
  1290. return 0;
  1291. }
  1292. const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
  1293. const size_t totalFloats = splitCount * (size_t)(headDim + 2);
  1294. float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
  1295. if (buf == nullptr) {
  1296. return 1;
  1297. }
  1298. float* partialMax = buf;
  1299. float* partialSum = partialMax + splitCount;
  1300. float* partialOut = partialSum + splitCount;
  1301. dim3 blocks1(numTokens, numHeads, numSplits);
  1302. paged_attention_split_kv_batch_kernel<float><<<blocks1, 32>>>(
  1303. Q,
  1304. KBlocksFlat,
  1305. VBlocksFlat,
  1306. blockOffsets, kvLens, queryPos,
  1307. partialMax, partialSum, partialOut,
  1308. numTokens,
  1309. numHeads, numKVHeads, headDim,
  1310. blockSize,
  1311. scale,
  1312. numSplits, kPagedAttentionSplitSize
  1313. );
  1314. CHECK_CUDA(cudaGetLastError());
  1315. dim3 blocks2(numTokens, numHeads);
  1316. paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
  1317. partialMax, partialSum, partialOut,
  1318. out,
  1319. numTokens, numHeads, headDim, numSplits
  1320. );
  1321. CHECK_CUDA(cudaGetLastError());
  1322. cuda_free(buf);
  1323. return 0;
  1324. }
  1325. // RoPE kernel: applies rotary position embedding in place; one block per row, positions supplied per row.
  1326. __global__ void rope_kernel(float* x, const int* positions, int totalDim, int headDim, float theta) {
  1327. int seq = blockIdx.x;
  1328. int pos = positions[seq];
  1329. float* rowData = x + seq * totalDim;
  1330. int halfDim = headDim / 2;
  1331. // Each thread handles one (j, j+halfDim) pair across all heads.
  1332. // Compute sin/cos once per j and reuse across heads.
  1333. for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
  1334. const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
  1335. const float freq = pos * invFreq;
  1336. float sinF, cosF;
  1337. sincosf(freq, &sinF, &cosF);
  1338. for (int headStart = 0; headStart < totalDim; headStart += headDim) {
  1339. const int idx0 = headStart + j;
  1340. const int idx1 = idx0 + halfDim;
  1341. const float v0 = rowData[idx0];
  1342. const float v1 = rowData[idx1];
  1343. rowData[idx0] = v0 * cosF - v1 * sinF;
  1344. rowData[idx1] = v1 * cosF + v0 * sinF;
  1345. }
  1346. }
  1347. }
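// Note: pairs are (j, j + headDim/2) within each head (non-interleaved layout), the
// same convention the fused-RoPE attention kernels above assume.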
  1348. int cuda_rope_f32(float* x, const int* positions, int seqLen, int numHeads, int headDim, float theta) {
  1349. int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
  1350. rope_kernel<<<seqLen, threads>>>(x, positions, numHeads * headDim, headDim, theta);
  1351. CHECK_CUDA(cudaGetLastError());
  1352. return 0;
  1353. }
  1354. // Optimized RoPE kernel for single position (scalar pos)
  1355. __global__ void rope_kernel_single(float* x, int pos, int totalDim, int headDim, float theta) {
  1356. // seq is always 0 for single token
  1357. float* rowData = x; // x + 0 * totalDim
  1358. int halfDim = headDim / 2;
  1359. // Each thread handles one (j, j+halfDim) pair across all heads.
  1360. // Compute sin/cos once per j and reuse across heads.
  1361. for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
  1362. const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
  1363. const float freq = pos * invFreq;
  1364. float sinF, cosF;
  1365. sincosf(freq, &sinF, &cosF);
  1366. for (int headStart = 0; headStart < totalDim; headStart += headDim) {
  1367. const int idx0 = headStart + j;
  1368. const int idx1 = idx0 + halfDim;
  1369. const float v0 = rowData[idx0];
  1370. const float v1 = rowData[idx1];
  1371. rowData[idx0] = v0 * cosF - v1 * sinF;
  1372. rowData[idx1] = v1 * cosF + v0 * sinF;
  1373. }
  1374. }
  1375. }
  1376. int cuda_rope_f32_single(float* x, int pos, int numHeads, int headDim, float theta) {
  1377. int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
  1378. rope_kernel_single<<<1, threads>>>(x, pos, numHeads * headDim, headDim, theta);
  1379. CHECK_CUDA(cudaGetLastError());
  1380. return 0;
  1381. }
  1382. // Softmax kernel: one block per row
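// Numerically stable: subtracts the row max before exponentiating, then normalizes by
// the row sum. Assumes blockDim.x == 256 (the shared buffers are 256 wide), matching
// the launch in cuda_softmax_f32.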
  1383. __global__ void softmax_kernel(float* x, int cols) {
  1384. int row = blockIdx.x;
  1385. float* rowData = x + row * cols;
  1386. __shared__ float smax[256];
  1387. __shared__ float ssum[256];
  1388. // Find max
  1389. float threadMax = -INFINITY;
  1390. for (int i = threadIdx.x; i < cols; i += blockDim.x) {
  1391. threadMax = fmaxf(threadMax, rowData[i]);
  1392. }
  1393. smax[threadIdx.x] = threadMax;
  1394. __syncthreads();
  1395. for (int s = blockDim.x / 2; s > 0; s >>= 1) {
  1396. if (threadIdx.x < s) {
  1397. smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
  1398. }
  1399. __syncthreads();
  1400. }
  1401. float maxVal = smax[0];
  1402. // Compute exp and sum
  1403. float threadSum = 0.0f;
  1404. for (int i = threadIdx.x; i < cols; i += blockDim.x) {
  1405. float val = expf(rowData[i] - maxVal);
  1406. rowData[i] = val;
  1407. threadSum += val;
  1408. }
  1409. ssum[threadIdx.x] = threadSum;
  1410. __syncthreads();
  1411. for (int s = blockDim.x / 2; s > 0; s >>= 1) {
  1412. if (threadIdx.x < s) {
  1413. ssum[threadIdx.x] += ssum[threadIdx.x + s];
  1414. }
  1415. __syncthreads();
  1416. }
  1417. float sum = ssum[0];
  1418. // Normalize
  1419. for (int i = threadIdx.x; i < cols; i += blockDim.x) {
  1420. rowData[i] /= sum;
  1421. }
  1422. }
  1423. int cuda_softmax_f32(float* x, int rows, int cols) {
  1424. int threads = 256;
  1425. softmax_kernel<<<rows, threads>>>(x, cols);
  1426. CHECK_CUDA(cudaGetLastError());
  1427. return 0;
  1428. }
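// Illustrative round-trip sketch (assumed buffer names, not the engine's call
// site): copies a host [rows x cols] score matrix to the device, normalizes
// each row in place with cuda_softmax_f32, and copies the probabilities back.
static int example_softmax_f32_usage(const float* h_scores, float* h_probs, int rows, int cols) {
    size_t bytes = (size_t)rows * cols * sizeof(float);
    float* d_scores = nullptr;
    CHECK_CUDA(cudaMalloc(&d_scores, bytes));
    CHECK_CUDA(cudaMemcpy(d_scores, h_scores, bytes, cudaMemcpyHostToDevice));
    int rc = cuda_softmax_f32(d_scores, rows, cols); // one 256-thread block per row
    CHECK_CUDA(cudaMemcpy(h_probs, d_scores, bytes, cudaMemcpyDeviceToHost));
    CHECK_CUDA(cudaFree(d_scores));
    return rc;
}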
// ============================================================
// Top-K Logits Selection (for sampling without full D2H)
// ============================================================
#define TOPK_THREADS 256
#define TOPK_LOCAL 8
#define TOPK_SEGMENT (TOPK_THREADS * TOPK_LOCAL) // 2048
#define TOPK_MAX_K 64
static __device__ __forceinline__ float apply_rep_penalty(float score, bool hit, float penalty) {
    if (!hit) return score;
    if (score > 0.0f) return score / penalty;
    return score * penalty;
}
__global__ void topk_logits_kernel(
    const float* logits, int vocab,
    const int* rep_ids, int rep_count, float rep_penalty,
    int k,
    int* out_ids, float* out_scores
) {
    const int blockStart = blockIdx.x * TOPK_SEGMENT;
    const int tid = threadIdx.x;
    if (tid >= TOPK_THREADS) return;
    float localScores[TOPK_LOCAL];
    int localIds[TOPK_LOCAL];
    #pragma unroll
    for (int i = 0; i < TOPK_LOCAL; i++) {
        localScores[i] = -INFINITY;
        localIds[i] = -1;
    }
    // Each thread processes up to TOPK_LOCAL elements in a contiguous segment.
    #pragma unroll
    for (int j = 0; j < TOPK_LOCAL; j++) {
        int idx = blockStart + tid + j * TOPK_THREADS;
        if (idx >= vocab) break;
        float score = logits[idx];
        bool hit = false;
        // rep_count is small (<=64). Linear scan is fine.
        for (int r = 0; r < rep_count; r++) {
            if (rep_ids[r] == idx) {
                hit = true;
                break;
            }
        }
        score = apply_rep_penalty(score, hit, rep_penalty);
        // Insert into local top list (descending)
        int pos = TOPK_LOCAL;
        for (int t = 0; t < TOPK_LOCAL; t++) {
            if (score > localScores[t]) { pos = t; break; }
        }
        if (pos < TOPK_LOCAL) {
            for (int t = TOPK_LOCAL - 1; t > pos; t--) {
                localScores[t] = localScores[t-1];
                localIds[t] = localIds[t-1];
            }
            localScores[pos] = score;
            localIds[pos] = idx;
        }
    }
    __shared__ float shScores[TOPK_SEGMENT];
    __shared__ int shIds[TOPK_SEGMENT];
    // Write local results to shared candidate pool.
    #pragma unroll
    for (int j = 0; j < TOPK_LOCAL; j++) {
        int out = tid * TOPK_LOCAL + j;
        shScores[out] = localScores[j];
        shIds[out] = localIds[j];
    }
    __syncthreads();
    if (tid == 0) {
        // Block-level exact top-k from TOPK_SEGMENT candidates.
        if (k > TOPK_MAX_K) k = TOPK_MAX_K;
        float bestScores[TOPK_MAX_K];
        int bestIds[TOPK_MAX_K];
        for (int i = 0; i < k; i++) {
            bestScores[i] = -INFINITY;
            bestIds[i] = -1;
        }
        for (int i = 0; i < TOPK_SEGMENT; i++) {
            float score = shScores[i];
            int id = shIds[i];
            if (id < 0) continue;
            // Insert into best (descending)
            if (score <= bestScores[k-1]) continue;
            int pos = k;
            for (int t = 0; t < k; t++) {
                if (score > bestScores[t]) { pos = t; break; }
            }
            if (pos < k) {
                for (int t = k - 1; t > pos; t--) {
                    bestScores[t] = bestScores[t-1];
                    bestIds[t] = bestIds[t-1];
                }
                bestScores[pos] = score;
                bestIds[pos] = id;
            }
        }
        int base = blockIdx.x * k;
        for (int i = 0; i < k; i++) {
            out_ids[base + i] = bestIds[i];
            out_scores[base + i] = bestScores[i];
        }
    }
}
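// The launcher below runs one block per TOPK_SEGMENT (2048) logits, and each
// block writes its own k candidates to out_ids/out_scores at offset
// blockIdx.x * k. The final merge of the ceil(vocab / 2048) * k per-block
// candidates down to a single global top-k is assumed to happen in the caller;
// it is not part of this file.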
int cuda_topk_logits_f32(
    const float* logits, int vocab,
    const int* rep_ids, int rep_count, float rep_penalty,
    int k,
    int* out_ids, float* out_scores
) {
    if (k <= 0) return 0;
    if (k > TOPK_MAX_K) k = TOPK_MAX_K;
    int blocks = (vocab + TOPK_SEGMENT - 1) / TOPK_SEGMENT;
    dim3 grid(blocks);
    dim3 block(TOPK_THREADS);
    topk_logits_kernel<<<grid, block>>>(logits, vocab, rep_ids, rep_count, rep_penalty, k, out_ids, out_scores);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
// ============================================================
// Attention Kernel
// Computes: softmax(Q @ K.T * scale + causal_mask) @ V
// ============================================================
__global__ void attention_kernel(
    const float* Q, const float* K, const float* V, float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    float scale, int startPos
) {
    // Each block handles one (seq, head) pair
    int seq = blockIdx.x;
    int head = blockIdx.y;
    int kvHead = head / (numHeads / numKVHeads); // GQA support
    const float* q = Q + seq * numHeads * headDim + head * headDim;
    float* o = out + seq * numHeads * headDim + head * headDim;
    // Shared memory for attention scores
    extern __shared__ float shared[];
    float* scores = shared; // [kvLen]
    // Compute Q @ K.T for this head
    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
        const float* k = K + kv * numKVHeads * headDim + kvHead * headDim;
        float dot = 0.0f;
        for (int d = 0; d < headDim; d++) {
            dot += q[d] * k[d];
        }
        // Apply causal mask
        int queryPos = startPos + seq;
        int keyPos = kv;
        if (keyPos > queryPos) {
            dot = -INFINITY;
        }
        scores[kv] = dot * scale;
    }
    __syncthreads();
    // Softmax over scores
    // Find max
    float maxVal = -INFINITY;
    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
        maxVal = fmaxf(maxVal, scores[kv]);
    }
    // Reduce max across threads
    __shared__ float smax[256];
    smax[threadIdx.x] = maxVal;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
        }
        __syncthreads();
    }
    maxVal = smax[0];
    // Exp and sum
    float sum = 0.0f;
    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
        float val = expf(scores[kv] - maxVal);
        scores[kv] = val;
        sum += val;
    }
    __shared__ float ssum[256];
    ssum[threadIdx.x] = sum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            ssum[threadIdx.x] += ssum[threadIdx.x + s];
        }
        __syncthreads();
    }
    sum = ssum[0];
    // Normalize
    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
        scores[kv] /= sum;
    }
    __syncthreads();
    // Compute weighted sum of V
    for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
        float val = 0.0f;
        for (int kv = 0; kv < kvLen; kv++) {
            const float* v = V + kv * numKVHeads * headDim + kvHead * headDim;
            val += scores[kv] * v[d];
        }
        o[d] = val;
    }
}
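// Paged variant: K and V are not contiguous. KBlocks/VBlocks are device arrays
// of device pointers, one per KV-cache page; page p holds blockSize tokens laid
// out as [blockSize x numKVHeads x headDim], so token kv lives in page
// kv / blockSize at row kv % blockSize. The kernel below makes three passes
// over the keys (max, sum of exponentials, weighted V accumulation) instead of
// materializing the score row in shared memory.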
__global__ void paged_attention_kernel(
    const float* Q,
    const float* const* KBlocks,
    const float* const* VBlocks,
    float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    int blockSize,
    float scale, int startPos
) {
    int seq = blockIdx.x;
    int head = blockIdx.y;
    int kvHead = head / (numHeads / numKVHeads);
    const float* q = Q + seq * numHeads * headDim + head * headDim;
    float* o = out + seq * numHeads * headDim + head * headDim;
    const int kvStride = numKVHeads * headDim;
    int queryPos = startPos + seq;
    // Pass 1: max
    float localMax = -INFINITY;
    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
        int keyPos = kv;
        if (keyPos > queryPos) {
            continue;
        }
        const int blockIdxKV = kv / blockSize;
        const int blockOff = kv % blockSize;
        const float* kBlock = KBlocks[blockIdxKV];
        const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
        float dot = 0.0f;
        for (int d = 0; d < headDim; d++) {
            dot += q[d] * k[d];
        }
        float score = dot * scale;
        localMax = fmaxf(localMax, score);
    }
    __shared__ float smax[256];
    smax[threadIdx.x] = localMax;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
        }
        __syncthreads();
    }
    float maxVal = smax[0];
    // Pass 2: sum exp
    float localSum = 0.0f;
    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
        int keyPos = kv;
        if (keyPos > queryPos) {
            continue;
        }
        const int blockIdxKV = kv / blockSize;
        const int blockOff = kv % blockSize;
        const float* kBlock = KBlocks[blockIdxKV];
        const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
        float dot = 0.0f;
        for (int d = 0; d < headDim; d++) {
            dot += q[d] * k[d];
        }
        float score = dot * scale;
        localSum += expf(score - maxVal);
    }
    __shared__ float ssum[256];
    ssum[threadIdx.x] = localSum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            ssum[threadIdx.x] += ssum[threadIdx.x + s];
        }
        __syncthreads();
    }
    float sumVal = ssum[0];
    float invSum = (sumVal > 0.0f) ? (1.0f / sumVal) : 0.0f;
    // Pass 3: output accumulation
    for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
        float outVal = 0.0f;
        for (int kv = 0; kv < kvLen; kv++) {
            int keyPos = kv;
            if (keyPos > queryPos) {
                break;
            }
            const int blockIdxKV = kv / blockSize;
            const int blockOff = kv % blockSize;
            const float* kBlock = KBlocks[blockIdxKV];
            const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
            float dot = 0.0f;
            for (int kd = 0; kd < headDim; kd++) {
                dot += q[kd] * k[kd];
            }
            float score = dot * scale;
            float w = expf(score - maxVal) * invSum;
            const float* vBlock = VBlocks[blockIdxKV];
            const float* v = vBlock + blockOff * kvStride + kvHead * headDim;
            outVal += w * v[d];
        }
        o[d] = outVal;
    }
}
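// Batched decode variant: one block per (token, head), streaming over the keys
// with an online softmax so no score buffer is needed. For each new score s the
// running state (m = max so far, l = sum of exponentials, acc = un-normalized
// output) is updated as
//   m' = max(m, s);  l' = l * exp(m - m') + exp(s - m');
//   acc' = acc * exp(m - m') + exp(s - m') * v
// and the output is acc / l after the last key.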
__global__ void paged_attention_batch_kernel(
    const float* Q,
    const float* const* KBlocksFlat,
    const float* const* VBlocksFlat,
    const int* blockOffsets,
    const int* kvLens,
    const int* queryPos,
    float* out,
    int numTokens,
    int numHeads, int numKVHeads, int headDim,
    int blockSize,
    float scale
) {
    int tok = blockIdx.x;
    int head = blockIdx.y;
    if (tok >= numTokens) {
        return;
    }
    int kvHead = head / (numHeads / numKVHeads);
    const float* q = Q + tok * numHeads * headDim + head * headDim;
    float* o = out + tok * numHeads * headDim + head * headDim;
    int kvLen = kvLens[tok];
    int qPos = queryPos[tok];
    int base = blockOffsets[tok];
    const int kvStride = numKVHeads * headDim;
    const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
    // Running un-normalized output for this thread's output dimension.
    // Only threads with threadIdx.x < headDim contribute to the final write.
    float acc = 0.0f;
    __shared__ float m;
    __shared__ float l;
    __shared__ float alpha;
    __shared__ float beta;
    __shared__ float dotShared;
    if (threadIdx.x == 0) {
        m = -INFINITY;
        l = 0.0f;
    }
    __syncthreads();
    for (int kv = 0; kv < effectiveLen; kv++) {
        const int bidx = kv / blockSize;
        const int boff = kv % blockSize;
        const float* kBlock = KBlocksFlat[base + bidx];
        const float* k = kBlock + boff * kvStride + kvHead * headDim;
        float partial = 0.0f;
        for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
            partial = fmaf(q[d], k[d], partial);
        }
        // Block reduction (sum): warp shuffle, then combine per-warp sums.
        for (int offset = 16; offset > 0; offset >>= 1) {
            partial += __shfl_down_sync(0xffffffff, partial, offset);
        }
        __shared__ float warpSum[8]; // supports up to 256 threads per block
        int lane = threadIdx.x & 31;
        int warp = threadIdx.x >> 5;
        if (lane == 0) {
            warpSum[warp] = partial;
        }
        __syncthreads();
        if (warp == 0) {
            // Only read per-warp sums that were actually written (guards block
            // sizes below 256 threads).
            const int numWarps = (blockDim.x + 31) >> 5;
            float v = (lane < numWarps) ? warpSum[lane] : 0.0f;
            for (int offset = 16; offset > 0; offset >>= 1) {
                v += __shfl_down_sync(0xffffffff, v, offset);
            }
            if (lane == 0) {
                dotShared = v;
            }
        }
        __syncthreads();
        float score = dotShared * scale;
        if (threadIdx.x == 0) {
            // Online softmax update: rescale the running sum when the max changes.
            float newM = fmaxf(m, score);
            float a = expf(m - newM);
            float b = expf(score - newM);
            m = newM;
            l = l * a + b;
            alpha = a;
            beta = b;
        }
        __syncthreads();
        if (threadIdx.x < headDim) {
            const float* vBlock = VBlocksFlat[base + bidx];
            const float* v = vBlock + boff * kvStride + kvHead * headDim;
            acc = fmaf(beta, v[threadIdx.x], acc * alpha);
        }
        __syncthreads();
    }
    if (threadIdx.x < headDim) {
        float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
        o[threadIdx.x] = acc * invL;
    }
}
int cuda_attention_f32(
    const float* Q, const float* K, const float* V, float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    float scale, int startPos
) {
    dim3 blocks(seqLen, numHeads);
    int threads = 256;
    size_t sharedMem = kvLen * sizeof(float);
    attention_kernel<<<blocks, threads, sharedMem>>>(
        Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
    );
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
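// Split-KV ("flash decoding") path: when the context is long but there are few
// (query, head) pairs, the KV range is cut into numSplits chunks of
// kPagedAttentionSplitSize tokens, each producing a partial max, partial
// exp-sum, and partial un-normalized output. The reduce kernel then merges the
// splits with the same rescaling used by the online softmax above. The split
// kernels themselves are defined elsewhere in this file.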
int cuda_paged_attention_f32(
    const float* Q,
    const float* const* KBlocks,
    const float* const* VBlocks,
    float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    int blockSize,
    float scale, int startPos
) {
    // Split-KV Flash Decoding for long contexts.
    const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
    const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
    const int qhCount = seqLen * numHeads;
    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
    if (!useSplit) {
        dim3 blocks(seqLen, numHeads);
        int threads = 256;
        paged_attention_kernel<<<blocks, threads>>>(
            Q, KBlocks, VBlocks, out,
            seqLen, kvLen, numHeads, numKVHeads, headDim,
            blockSize, scale, startPos
        );
        CHECK_CUDA(cudaGetLastError());
        return 0;
    }
    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
    if (buf == nullptr) {
        return 1;
    }
    float* partialMax = buf;
    float* partialSum = partialMax + splitCount;
    float* partialOut = partialSum + splitCount;
    dim3 blocks1(seqLen, numHeads, numSplits);
    paged_attention_split_kv_kernel<float><<<blocks1, 32>>>(
        Q, KBlocks, VBlocks,
        partialMax, partialSum, partialOut,
        seqLen, kvLen, numHeads, numKVHeads, headDim,
        blockSize,
        scale, startPos,
        numSplits, kPagedAttentionSplitSize
    );
    CHECK_CUDA(cudaGetLastError());
    dim3 blocks2(seqLen, numHeads);
    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
        partialMax, partialSum, partialOut,
        out,
        seqLen, numHeads, headDim, numSplits
    );
    CHECK_CUDA(cudaGetLastError());
    cuda_free(buf);
    return 0;
}
int cuda_attention_f32_timed(
    const float* Q, const float* K, const float* V, float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    float scale, int startPos, float* ms
) {
    cudaEvent_t evStart;
    cudaEvent_t evStop;
    CHECK_CUDA(cudaEventCreate(&evStart));
    CHECK_CUDA(cudaEventCreate(&evStop));
    dim3 blocks(seqLen, numHeads);
    int threads = 256;
    size_t sharedMem = kvLen * sizeof(float);
    CHECK_CUDA(cudaEventRecord(evStart));
    attention_kernel<<<blocks, threads, sharedMem>>>(
        Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
    );
    CHECK_CUDA(cudaEventRecord(evStop));
    CHECK_CUDA(cudaEventSynchronize(evStop));
    float elapsed = 0.0f;
    CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
    if (ms != NULL) {
        *ms = elapsed;
    }
    CHECK_CUDA(cudaEventDestroy(evStart));
    CHECK_CUDA(cudaEventDestroy(evStop));
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
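// Illustrative benchmark sketch (assumed buffer names, not referenced by the
// build): times a single attention launch and prints the kernel-only
// milliseconds measured with CUDA events by cuda_attention_f32_timed.
static void example_attention_timing(const float* d_Q, const float* d_K, const float* d_V,
                                     float* d_out, int seqLen, int kvLen,
                                     int numHeads, int numKVHeads, int headDim) {
    float ms = 0.0f;
    float scale = 1.0f / sqrtf((float)headDim); // conventional attention scaling
    cuda_attention_f32_timed(d_Q, d_K, d_V, d_out, seqLen, kvLen,
                             numHeads, numKVHeads, headDim, scale, /*startPos=*/0, &ms);
    printf("attention_kernel: %.3f ms\n", ms);
}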
// Debug helper
int cuda_print_struct_sizes() {
    printf("GPU Struct Sizes:\n");
    printf("BlockQ2_K: %zu\n", sizeof(BlockQ2_K));
    printf("BlockQ3_K: %zu\n", sizeof(BlockQ3_K));
    printf("BlockQ4_K: %zu\n", sizeof(BlockQ4_K));
    printf("BlockQ6_K: %zu\n", sizeof(BlockQ6_K));
    printf("BlockQ8_K: %zu\n", sizeof(BlockQ8_K));
    return 0;
}