- #include "cuda_common.cuh"
- #include <cuda_fp16.h>
- #include <stdint.h>
- #include <string.h>
- namespace {
- constexpr int kPagedAttentionSplitSize = 1024;
- constexpr int kPagedAttentionSplitQHThreshold = 4096; // queryCount*numHeads threshold
- } // namespace
- // ============================================================
- // Fused RoPE helpers (constant step table + fast complex pow)
- // ============================================================
- // Stores cos/sin of per-dimension "step" angle (invFreq) for RoPE.
- // Only indices [0, headDim/2) are used. Max supported headDim is 256.
- __device__ __constant__ float rope_cos_step_const[128];
- __device__ __constant__ float rope_sin_step_const[128];
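- // Host-side cache of the last-uploaded step table, one slot per device
- // (device ordinals >= 32 fall back to slot 0; see ensure_rope_step_table).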
- static int g_rope_step_inited[32] = {0};
- static int g_rope_step_head_dim[32] = {0};
- static uint32_t g_rope_step_theta_bits[32] = {0};
- static int ensure_rope_step_table(int headDim, float theta) {
- if (headDim <= 0 || headDim > 256 || (headDim & 1) != 0) {
- return 1;
- }
- int dev = 0;
- CHECK_CUDA(cudaGetDevice(&dev));
- if (dev < 0 || dev >= 32) {
- dev = 0;
- }
- uint32_t thetaBits = 0;
- memcpy(&thetaBits, &theta, sizeof(thetaBits));
- if (g_rope_step_inited[dev] && g_rope_step_head_dim[dev] == headDim && g_rope_step_theta_bits[dev] == thetaBits) {
- return 0;
- }
- float cosStep[128];
- float sinStep[128];
- const int halfDim = headDim / 2;
- for (int j = 0; j < 128; j++) {
- if (j < halfDim) {
- // invFreq = theta^(-2j/headDim)
- const double exponent = -2.0 * (double)j / (double)headDim;
- const double invFreq = pow((double)theta, exponent);
- cosStep[j] = (float)cos(invFreq);
- sinStep[j] = (float)sin(invFreq);
- } else {
- cosStep[j] = 1.0f;
- sinStep[j] = 0.0f;
- }
- }
- CHECK_CUDA(cudaMemcpyToSymbol(rope_cos_step_const, cosStep, sizeof(cosStep), 0, cudaMemcpyHostToDevice));
- CHECK_CUDA(cudaMemcpyToSymbol(rope_sin_step_const, sinStep, sizeof(sinStep), 0, cudaMemcpyHostToDevice));
- g_rope_step_inited[dev] = 1;
- g_rope_step_head_dim[dev] = headDim;
- g_rope_step_theta_bits[dev] = thetaBits;
- return 0;
- }
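- // The fused-RoPE launchers below call ensure_rope_step_table() before their
- // first kernel launch; the per-device cache makes repeat calls with the same
- // (headDim, theta) effectively free.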
- __device__ __forceinline__ float2 complex_mul_f2(float2 a, float2 b) {
- // (a.x + i a.y) * (b.x + i b.y)
- return make_float2(
- fmaf(a.x, b.x, -a.y * b.y),
- fmaf(a.x, b.y, a.y * b.x)
- );
- }
- __device__ __forceinline__ float2 complex_pow_int(float2 base, int exp) {
- float2 result = make_float2(1.0f, 0.0f);
- float2 b = base;
- int e = exp;
- while (e > 0) {
- if (e & 1) {
- result = complex_mul_f2(result, b);
- }
- b = complex_mul_f2(b, b);
- e >>= 1;
- }
- return result;
- }
- __device__ __forceinline__ void rope_advance_neg(float& cosv, float& sinv, float cosStep, float sinStep) {
- // Multiply by exp(-i*step): (cos + i sin) * (cosStep - i sinStep)
- const float c = cosv;
- const float s = sinv;
- cosv = fmaf(c, cosStep, s * sinStep);
- sinv = fmaf(s, cosStep, -c * sinStep);
- }
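- // Together these helpers evaluate cos/sin(delta * step) for delta = qPos - kv
- // without per-token sincosf: complex_pow_int seeds the phase at kv = 0
- // (delta = qPos) in O(log qPos) complex multiplies, and rope_advance_neg then
- // rotates by -step as kv increments. Example: with step angle s and qPos = 5,
- // the seed is (cos 5s, sin 5s); three advances give (cos 2s, sin 2s), the
- // phase used at kv = 3.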
- // ============================================================
- // Neural Network Operations
- // ============================================================
- // RMSNorm kernel: one block per row. The two-level reduction below assumes
- // blockDim.x == 256 (8 warps), matching the launcher.
- __global__ void rmsnorm_kernel(float* x, const float* w, int dim, float eps) {
- int row = blockIdx.x;
- float* rowData = x + row * dim;
- float sum = 0.0f;
- for (int i = threadIdx.x; i < dim; i += blockDim.x) {
- float v = rowData[i];
- sum = fmaf(v, v, sum);
- }
- // Warp reduce
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- __shared__ float warpSum[8];
- __shared__ float rms;
- int lane = threadIdx.x & 31;
- int warp = threadIdx.x >> 5;
- if (lane == 0) {
- warpSum[warp] = sum;
- }
- __syncthreads();
- if (warp == 0) {
- float v = (lane < 8) ? warpSum[lane] : 0.0f;
- for (int offset = 16; offset > 0; offset >>= 1) {
- v += __shfl_down_sync(0xffffffff, v, offset);
- }
- if (lane == 0) {
- rms = rsqrtf(v / dim + eps);
- }
- }
- __syncthreads();
- for (int i = threadIdx.x; i < dim; i += blockDim.x) {
- rowData[i] = rowData[i] * rms * w[i];
- }
- }
- __global__ void paged_attention_batch_kernel_f16kv(
- const float* Q,
- const unsigned short* const* KBlocksFlat,
- const unsigned short* const* VBlocksFlat,
- const int* blockOffsets,
- const int* kvLens,
- const int* queryPos,
- float* out,
- int numTokens,
- int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale
- ) {
- const int tok = blockIdx.x;
- const int head = blockIdx.y;
- const int lane = threadIdx.x & 31;
- if (tok >= numTokens) {
- return;
- }
- // One warp per (tok, head)
- if (threadIdx.x >= 32) {
- return;
- }
- const int kvHead = head / (numHeads / numKVHeads);
- const float* q = Q + tok * numHeads * headDim + head * headDim;
- float* o = out + tok * numHeads * headDim + head * headDim;
- const int kvLen = kvLens[tok];
- const int qPos = queryPos[tok];
- const int base = blockOffsets[tok];
- const int kvStride = numKVHeads * headDim;
- const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
- // Cache Q in registers (per lane) to avoid reloading it for every KV token.
- float qreg[8];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- qreg[i] = (d < headDim) ? q[d] : 0.0f;
- }
- // Support headDim up to 256 (<= 8 values per lane)
- float acc[8];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- acc[i] = 0.0f;
- }
- float m = -INFINITY;
- float l = 0.0f;
- for (int kv = 0; kv < effectiveLen; kv++) {
- const int bidx = kv / blockSize;
- const int boff = kv % blockSize;
- const __half* kBlock = reinterpret_cast<const __half*>(KBlocksFlat[base + bidx]);
- const __half* k = kBlock + boff * kvStride + kvHead * headDim;
- float dot = 0.0f;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- dot = fmaf(qreg[i], __half2float(k[d]), dot);
- }
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- dot += __shfl_down_sync(0xffffffff, dot, offset);
- }
- dot = __shfl_sync(0xffffffff, dot, 0);
- const float score = dot * scale;
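- // Online softmax update (flash-attention style): lane 0 keeps the running
- // max m and normalizer l; alpha rescales previously accumulated values when
- // the max grows, beta weights the current score. Both factors are then
- // broadcast to the whole warp.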
- float alpha = 1.0f;
- float beta = 0.0f;
- if (lane == 0) {
- const float newM = fmaxf(m, score);
- alpha = __expf(m - newM);
- beta = __expf(score - newM);
- m = newM;
- l = l * alpha + beta;
- }
- alpha = __shfl_sync(0xffffffff, alpha, 0);
- beta = __shfl_sync(0xffffffff, beta, 0);
- l = __shfl_sync(0xffffffff, l, 0);
- const __half* vBlock = reinterpret_cast<const __half*>(VBlocksFlat[base + bidx]);
- const __half* v = vBlock + boff * kvStride + kvHead * headDim;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
- }
- }
- }
- const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- o[d] = acc[i] * invL;
- }
- }
- }
- // Fused RoPE + paged attention (batch, f16 KV). Expects un-rotated Q/K.
- __global__ void paged_attention_batch_kernel_f16kv_rope(
- const float* Q,
- const unsigned short* const* KBlocksFlat,
- const unsigned short* const* VBlocksFlat,
- const int* blockOffsets,
- const int* kvLens,
- const int* queryPos,
- float* out,
- int numTokens,
- int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale
- ) {
- const int tok = blockIdx.x;
- const int head = blockIdx.y;
- const int lane = threadIdx.x & 31;
- if (tok >= numTokens) {
- return;
- }
- // One warp per (tok, head)
- if (threadIdx.x >= 32) {
- return;
- }
- const int kvHead = head / (numHeads / numKVHeads);
- const float* q = Q + tok * numHeads * headDim + head * headDim;
- float* o = out + tok * numHeads * headDim + head * headDim;
- const int kvLen = kvLens[tok];
- const int qPos = queryPos[tok];
- const int base = blockOffsets[tok];
- const int kvStride = numKVHeads * headDim;
- const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
- const int halfDim = headDim >> 1;
- // Cache Q pairs + per-dim RoPE phase for delta=(qPos - kv) with kv starting at 0.
- float q0[4];
- float q1[4];
- float cosStep[4];
- float sinStep[4];
- float cosDelta[4];
- float sinDelta[4];
- int pairCount = 0;
- for (int j = lane; j < halfDim; j += 32) {
- q0[pairCount] = q[j];
- q1[pairCount] = q[j + halfDim];
- cosStep[pairCount] = rope_cos_step_const[j];
- sinStep[pairCount] = rope_sin_step_const[j];
- const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
- const float2 ph = complex_pow_int(baseStep, qPos);
- cosDelta[pairCount] = ph.x;
- sinDelta[pairCount] = ph.y;
- pairCount++;
- }
- // Support headDim up to 256.
- float acc[8];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- acc[i] = 0.0f;
- }
- float m = -INFINITY;
- float l = 0.0f;
- for (int kv = 0; kv < effectiveLen; kv++) {
- const int bidx = kv / blockSize;
- const int boff = kv % blockSize;
- const __half* kBlock = reinterpret_cast<const __half*>(KBlocksFlat[base + bidx]);
- const __half* k = kBlock + boff * kvStride + kvHead * headDim;
- float dot = 0.0f;
- #pragma unroll
- for (int pi = 0; pi < 4; pi++) {
- if (pi >= pairCount) {
- break;
- }
- const int j = lane + 32 * pi;
- if (j >= halfDim) {
- continue;
- }
- const float k0 = __half2float(k[j]);
- const float k1 = __half2float(k[j + halfDim]);
- const float a = fmaf(q0[pi], k0, q1[pi] * k1); // q0*k0 + q1*k1
- const float b = fmaf(q0[pi], k1, -q1[pi] * k0); // q0*k1 - q1*k0
- dot = fmaf(cosDelta[pi], a, dot);
- dot = fmaf(sinDelta[pi], b, dot);
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- dot += __shfl_down_sync(0xffffffff, dot, offset);
- }
- dot = __shfl_sync(0xffffffff, dot, 0);
- // Advance delta -> delta-1 for next kv.
- #pragma unroll
- for (int pi = 0; pi < 4; pi++) {
- if (pi >= pairCount) {
- break;
- }
- rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
- }
- const float score = dot * scale;
- float alpha = 1.0f;
- float beta = 0.0f;
- if (lane == 0) {
- const float newM = fmaxf(m, score);
- alpha = __expf(m - newM);
- beta = __expf(score - newM);
- m = newM;
- l = l * alpha + beta;
- }
- alpha = __shfl_sync(0xffffffff, alpha, 0);
- beta = __shfl_sync(0xffffffff, beta, 0);
- l = __shfl_sync(0xffffffff, l, 0);
- const __half* vBlock = reinterpret_cast<const __half*>(VBlocksFlat[base + bidx]);
- const __half* v = vBlock + boff * kvStride + kvHead * headDim;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
- }
- }
- }
- const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- o[d] = acc[i] * invL;
- }
- }
- }
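- // Note: the incremental phase rotation accumulates rounding error over long
- // KV ranges. The split-KV variants below re-seed the phase per split
- // (ropeExp = qPos - splitStart), bounding the incremental drift to at most
- // splitSize steps.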
- __global__ void cast_f32_to_f16_kernel(const float* src, __half* dst, int n) {
- int i = blockIdx.x * blockDim.x + threadIdx.x;
- if (i < n) {
- dst[i] = __float2half_rn(src[i]);
- }
- }
- __global__ void paged_attention_kernel_f16kv(
- const float* Q,
- const unsigned short* const* KBlocks,
- const unsigned short* const* VBlocks,
- float* out,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale, int startPos
- ) {
- const int seq = blockIdx.x;
- const int head = blockIdx.y;
- const int lane = threadIdx.x & 31;
- if (seq >= seqLen) {
- return;
- }
- // One warp per (seq, head)
- if (threadIdx.x >= 32) {
- return;
- }
- const int kvHead = head / (numHeads / numKVHeads);
- const float* q = Q + seq * numHeads * headDim + head * headDim;
- float* o = out + seq * numHeads * headDim + head * headDim;
- const int kvStride = numKVHeads * headDim;
- const int queryPos = startPos + seq;
- const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
- // Cache Q in registers (per lane) to avoid reloading it for every KV token.
- float qreg[8];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- qreg[i] = (d < headDim) ? q[d] : 0.0f;
- }
- float acc[8];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- acc[i] = 0.0f;
- }
- float m = -INFINITY;
- float l = 0.0f;
- for (int kv = 0; kv < effectiveLen; kv++) {
- const int blockIdxKV = kv / blockSize;
- const int blockOff = kv % blockSize;
- const __half* kBlock = reinterpret_cast<const __half*>(KBlocks[blockIdxKV]);
- const __half* k = kBlock + blockOff * kvStride + kvHead * headDim;
- float dot = 0.0f;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- dot = fmaf(qreg[i], __half2float(k[d]), dot);
- }
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- dot += __shfl_down_sync(0xffffffff, dot, offset);
- }
- dot = __shfl_sync(0xffffffff, dot, 0);
- const float score = dot * scale;
- float alpha = 1.0f;
- float beta = 0.0f;
- if (lane == 0) {
- const float newM = fmaxf(m, score);
- alpha = __expf(m - newM);
- beta = __expf(score - newM);
- m = newM;
- l = l * alpha + beta;
- }
- alpha = __shfl_sync(0xffffffff, alpha, 0);
- beta = __shfl_sync(0xffffffff, beta, 0);
- l = __shfl_sync(0xffffffff, l, 0);
- const __half* vBlock = reinterpret_cast<const __half*>(VBlocks[blockIdxKV]);
- const __half* v = vBlock + blockOff * kvStride + kvHead * headDim;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
- }
- }
- }
- const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- o[d] = acc[i] * invL;
- }
- }
- }
- // Fused RoPE + paged attention (single, f16 KV). Expects un-rotated Q/K.
- __global__ void paged_attention_kernel_f16kv_rope(
- const float* Q,
- const unsigned short* const* KBlocks,
- const unsigned short* const* VBlocks,
- float* out,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale, int startPos
- ) {
- const int seq = blockIdx.x;
- const int head = blockIdx.y;
- const int lane = threadIdx.x & 31;
- if (seq >= seqLen) {
- return;
- }
- // One warp per (seq, head)
- if (threadIdx.x >= 32) {
- return;
- }
- const int kvHead = head / (numHeads / numKVHeads);
- const float* q = Q + seq * numHeads * headDim + head * headDim;
- float* o = out + seq * numHeads * headDim + head * headDim;
- const int kvStride = numKVHeads * headDim;
- const int queryPos = startPos + seq;
- const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
- const int halfDim = headDim >> 1;
- float q0[4];
- float q1[4];
- float cosStep[4];
- float sinStep[4];
- float cosDelta[4];
- float sinDelta[4];
- int pairCount = 0;
- for (int j = lane; j < halfDim; j += 32) {
- q0[pairCount] = q[j];
- q1[pairCount] = q[j + halfDim];
- cosStep[pairCount] = rope_cos_step_const[j];
- sinStep[pairCount] = rope_sin_step_const[j];
- const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
- const float2 ph = complex_pow_int(baseStep, queryPos);
- cosDelta[pairCount] = ph.x;
- sinDelta[pairCount] = ph.y;
- pairCount++;
- }
- float acc[8];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- acc[i] = 0.0f;
- }
- float m = -INFINITY;
- float l = 0.0f;
- for (int kv = 0; kv < effectiveLen; kv++) {
- const int blockIdxKV = kv / blockSize;
- const int blockOff = kv % blockSize;
- const __half* kBlock = reinterpret_cast<const __half*>(KBlocks[blockIdxKV]);
- const __half* k = kBlock + blockOff * kvStride + kvHead * headDim;
- float dot = 0.0f;
- #pragma unroll
- for (int pi = 0; pi < 4; pi++) {
- if (pi >= pairCount) {
- break;
- }
- const int j = lane + 32 * pi;
- if (j >= halfDim) {
- continue;
- }
- const float k0 = __half2float(k[j]);
- const float k1 = __half2float(k[j + halfDim]);
- const float a = fmaf(q0[pi], k0, q1[pi] * k1);
- const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
- dot = fmaf(cosDelta[pi], a, dot);
- dot = fmaf(sinDelta[pi], b, dot);
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- dot += __shfl_down_sync(0xffffffff, dot, offset);
- }
- dot = __shfl_sync(0xffffffff, dot, 0);
- // Advance delta -> delta-1 for next kv.
- #pragma unroll
- for (int pi = 0; pi < 4; pi++) {
- if (pi >= pairCount) {
- break;
- }
- rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
- }
- const float score = dot * scale;
- float alpha = 1.0f;
- float beta = 0.0f;
- if (lane == 0) {
- const float newM = fmaxf(m, score);
- alpha = __expf(m - newM);
- beta = __expf(score - newM);
- m = newM;
- l = l * alpha + beta;
- }
- alpha = __shfl_sync(0xffffffff, alpha, 0);
- beta = __shfl_sync(0xffffffff, beta, 0);
- l = __shfl_sync(0xffffffff, l, 0);
- const __half* vBlock = reinterpret_cast<const __half*>(VBlocks[blockIdxKV]);
- const __half* v = vBlock + blockOff * kvStride + kvHead * headDim;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
- }
- }
- }
- const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- o[d] = acc[i] * invL;
- }
- }
- }
- template <typename T>
- __device__ __forceinline__ float load_kv(const T* p, int idx) {
- return p[idx];
- }
- template <>
- __device__ __forceinline__ float load_kv<__half>(const __half* p, int idx) {
- return __half2float(p[idx]);
- }
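- // load_kv abstracts the KV element type so the split-KV kernels below are
- // shared between the f32 and f16 cache layouts.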
- template <typename T, bool kUseRoPE = false>
- __global__ void paged_attention_split_kv_kernel(
- const float* Q,
- const T* const* KBlocks,
- const T* const* VBlocks,
- float* partialMax,
- float* partialSum,
- float* partialOut,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale, int startPos,
- int numSplits, int splitSize
- ) {
- const int seq = blockIdx.x;
- const int head = blockIdx.y;
- const int split = blockIdx.z;
- const int lane = threadIdx.x & 31;
- if (seq >= seqLen) {
- return;
- }
- // One warp per (seq, head, split).
- if (threadIdx.x >= 32) {
- return;
- }
- const int kvHead = head / (numHeads / numKVHeads);
- const float* q = Q + seq * numHeads * headDim + head * headDim;
- const int kvStride = numKVHeads * headDim;
- const int queryPos = startPos + seq;
- const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
- const int splitStart = split * splitSize;
- const int splitEnd = (splitStart + splitSize < effectiveLen) ? (splitStart + splitSize) : effectiveLen;
- const size_t qh = (size_t)seq * (size_t)numHeads + (size_t)head;
- const size_t splitIdx = qh * (size_t)numSplits + (size_t)split;
- float* outVec = partialOut + splitIdx * (size_t)headDim;
- if (splitStart >= splitEnd) {
- if (lane == 0) {
- partialMax[splitIdx] = -INFINITY;
- partialSum[splitIdx] = 0.0f;
- }
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- outVec[d] = 0.0f;
- }
- }
- return;
- }
- const int halfDim = headDim >> 1;
- const int ropeExp = queryPos - splitStart;
- float qreg[8];
- float q0[4];
- float q1[4];
- float cosStep[4];
- float sinStep[4];
- float cosDelta[4];
- float sinDelta[4];
- int pairCount = 0;
- if (kUseRoPE) {
- for (int j = lane; j < halfDim; j += 32) {
- q0[pairCount] = q[j];
- q1[pairCount] = q[j + halfDim];
- cosStep[pairCount] = rope_cos_step_const[j];
- sinStep[pairCount] = rope_sin_step_const[j];
- const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
- const float2 ph = complex_pow_int(baseStep, ropeExp);
- cosDelta[pairCount] = ph.x;
- sinDelta[pairCount] = ph.y;
- pairCount++;
- }
- } else {
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- qreg[i] = (d < headDim) ? q[d] : 0.0f;
- }
- }
- float acc[8];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- acc[i] = 0.0f;
- }
- float m = -INFINITY;
- float l = 0.0f;
- for (int kv = splitStart; kv < splitEnd; kv++) {
- const int blockIdxKV = kv / blockSize;
- const int blockOff = kv % blockSize;
- const T* kBlock = KBlocks[blockIdxKV];
- const T* k = kBlock + blockOff * kvStride + kvHead * headDim;
- float dot = 0.0f;
- if (kUseRoPE) {
- #pragma unroll
- for (int pi = 0; pi < 4; pi++) {
- if (pi >= pairCount) {
- break;
- }
- const int j = lane + 32 * pi;
- if (j >= halfDim) {
- continue;
- }
- const float k0 = load_kv<T>(k, j);
- const float k1 = load_kv<T>(k, j + halfDim);
- const float a = fmaf(q0[pi], k0, q1[pi] * k1);
- const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
- dot = fmaf(cosDelta[pi], a, dot);
- dot = fmaf(sinDelta[pi], b, dot);
- }
- } else {
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- dot = fmaf(qreg[i], load_kv<T>(k, d), dot);
- }
- }
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- dot += __shfl_down_sync(0xffffffff, dot, offset);
- }
- dot = __shfl_sync(0xffffffff, dot, 0);
- if (kUseRoPE) {
- #pragma unroll
- for (int pi = 0; pi < 4; pi++) {
- if (pi >= pairCount) {
- break;
- }
- rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
- }
- }
- const float score = dot * scale;
- float alpha = 1.0f;
- float beta = 0.0f;
- if (lane == 0) {
- const float newM = fmaxf(m, score);
- alpha = __expf(m - newM);
- beta = __expf(score - newM);
- m = newM;
- l = l * alpha + beta;
- }
- alpha = __shfl_sync(0xffffffff, alpha, 0);
- beta = __shfl_sync(0xffffffff, beta, 0);
- const T* vBlock = VBlocks[blockIdxKV];
- const T* v = vBlock + blockOff * kvStride + kvHead * headDim;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- acc[i] = acc[i] * alpha + beta * load_kv<T>(v, d);
- }
- }
- }
- if (lane == 0) {
- partialMax[splitIdx] = m;
- partialSum[splitIdx] = l;
- }
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- outVec[d] = acc[i];
- }
- }
- }
- template <typename T, bool kUseRoPE = false>
- __global__ void paged_attention_split_kv_batch_kernel(
- const float* Q,
- const T* const* KBlocksFlat,
- const T* const* VBlocksFlat,
- const int* blockOffsets,
- const int* kvLens,
- const int* queryPos,
- float* partialMax,
- float* partialSum,
- float* partialOut,
- int numTokens,
- int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale,
- int numSplits, int splitSize
- ) {
- const int tok = blockIdx.x;
- const int head = blockIdx.y;
- const int split = blockIdx.z;
- const int lane = threadIdx.x & 31;
- if (tok >= numTokens) {
- return;
- }
- // One warp per (tok, head, split).
- if (threadIdx.x >= 32) {
- return;
- }
- const int kvHead = head / (numHeads / numKVHeads);
- const float* q = Q + tok * numHeads * headDim + head * headDim;
- const int kvLen = kvLens[tok];
- const int qPos = queryPos[tok];
- const int base = blockOffsets[tok];
- const int kvStride = numKVHeads * headDim;
- const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
- const int splitStart = split * splitSize;
- const int splitEnd = (splitStart + splitSize < effectiveLen) ? (splitStart + splitSize) : effectiveLen;
- const size_t qh = (size_t)tok * (size_t)numHeads + (size_t)head;
- const size_t splitIdx = qh * (size_t)numSplits + (size_t)split;
- float* outVec = partialOut + splitIdx * (size_t)headDim;
- if (splitStart >= splitEnd) {
- if (lane == 0) {
- partialMax[splitIdx] = -INFINITY;
- partialSum[splitIdx] = 0.0f;
- }
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- outVec[d] = 0.0f;
- }
- }
- return;
- }
- const int halfDim = headDim >> 1;
- const int ropeExp = qPos - splitStart;
- float qreg[8];
- float q0[4];
- float q1[4];
- float cosStep[4];
- float sinStep[4];
- float cosDelta[4];
- float sinDelta[4];
- int pairCount = 0;
- if (kUseRoPE) {
- for (int j = lane; j < halfDim; j += 32) {
- q0[pairCount] = q[j];
- q1[pairCount] = q[j + halfDim];
- cosStep[pairCount] = rope_cos_step_const[j];
- sinStep[pairCount] = rope_sin_step_const[j];
- const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
- const float2 ph = complex_pow_int(baseStep, ropeExp);
- cosDelta[pairCount] = ph.x;
- sinDelta[pairCount] = ph.y;
- pairCount++;
- }
- } else {
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- qreg[i] = (d < headDim) ? q[d] : 0.0f;
- }
- }
- float acc[8];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- acc[i] = 0.0f;
- }
- float m = -INFINITY;
- float l = 0.0f;
- for (int kv = splitStart; kv < splitEnd; kv++) {
- const int bidx = kv / blockSize;
- const int boff = kv % blockSize;
- const T* kBlock = KBlocksFlat[base + bidx];
- const T* k = kBlock + boff * kvStride + kvHead * headDim;
- float dot = 0.0f;
- if (kUseRoPE) {
- #pragma unroll
- for (int pi = 0; pi < 4; pi++) {
- if (pi >= pairCount) {
- break;
- }
- const int j = lane + 32 * pi;
- if (j >= halfDim) {
- continue;
- }
- const float k0 = load_kv<T>(k, j);
- const float k1 = load_kv<T>(k, j + halfDim);
- const float a = fmaf(q0[pi], k0, q1[pi] * k1);
- const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
- dot = fmaf(cosDelta[pi], a, dot);
- dot = fmaf(sinDelta[pi], b, dot);
- }
- } else {
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- dot = fmaf(qreg[i], load_kv<T>(k, d), dot);
- }
- }
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- dot += __shfl_down_sync(0xffffffff, dot, offset);
- }
- dot = __shfl_sync(0xffffffff, dot, 0);
- if (kUseRoPE) {
- #pragma unroll
- for (int pi = 0; pi < 4; pi++) {
- if (pi >= pairCount) {
- break;
- }
- rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
- }
- }
- const float score = dot * scale;
- float alpha = 1.0f;
- float beta = 0.0f;
- if (lane == 0) {
- const float newM = fmaxf(m, score);
- alpha = __expf(m - newM);
- beta = __expf(score - newM);
- m = newM;
- l = l * alpha + beta;
- }
- alpha = __shfl_sync(0xffffffff, alpha, 0);
- beta = __shfl_sync(0xffffffff, beta, 0);
- const T* vBlock = VBlocksFlat[base + bidx];
- const T* v = vBlock + boff * kvStride + kvHead * headDim;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- acc[i] = acc[i] * alpha + beta * load_kv<T>(v, d);
- }
- }
- }
- if (lane == 0) {
- partialMax[splitIdx] = m;
- partialSum[splitIdx] = l;
- }
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- outVec[d] = acc[i];
- }
- }
- }
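- // Merges per-split partials with the log-sum-exp trick: for per-split
- // (m_s, l_s, o_s), with o_s stored unnormalized, the global max is
- // M = max_s m_s, the normalizer is L = sum_s l_s * exp(m_s - M), and
- // out = (1/L) * sum_s exp(m_s - M) * o_s.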
- __global__ void paged_attention_split_kv_reduce_kernel(
- const float* partialMax,
- const float* partialSum,
- const float* partialOut,
- float* out,
- int queryCount,
- int numHeads,
- int headDim,
- int numSplits
- ) {
- const int q = blockIdx.x;
- const int head = blockIdx.y;
- const int lane = threadIdx.x & 31;
- if (q >= queryCount) {
- return;
- }
- // One warp per (q, head).
- if (threadIdx.x >= 32) {
- return;
- }
- const size_t qh = (size_t)q * (size_t)numHeads + (size_t)head;
- const size_t base = qh * (size_t)numSplits;
- float gmax = -INFINITY;
- if (lane == 0) {
- for (int s = 0; s < numSplits; s++) {
- gmax = fmaxf(gmax, partialMax[base + (size_t)s]);
- }
- }
- gmax = __shfl_sync(0xffffffff, gmax, 0);
- if (gmax == -INFINITY) {
- const size_t outBase = qh * (size_t)headDim;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- out[outBase + (size_t)d] = 0.0f;
- }
- }
- return;
- }
- float sum = 0.0f;
- if (lane == 0) {
- for (int s = 0; s < numSplits; s++) {
- const float l = partialSum[base + (size_t)s];
- sum += l * __expf(partialMax[base + (size_t)s] - gmax);
- }
- }
- sum = __shfl_sync(0xffffffff, sum, 0);
- const float invSum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
- float acc[8];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- acc[i] = 0.0f;
- }
- for (int s = 0; s < numSplits; s++) {
- float scale = 0.0f;
- if (lane == 0) {
- scale = __expf(partialMax[base + (size_t)s] - gmax);
- }
- scale = __shfl_sync(0xffffffff, scale, 0);
- const float* vec = partialOut + (base + (size_t)s) * (size_t)headDim;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- acc[i] = fmaf(scale, vec[d], acc[i]);
- }
- }
- }
- const size_t outBase = qh * (size_t)headDim;
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int d = lane + 32 * i;
- if (d < headDim) {
- out[outBase + (size_t)d] = acc[i] * invSum;
- }
- }
- }
- int cuda_rmsnorm_f32(float* x, const float* w, int seqLen, int dim, float eps) {
- int threads = 256;
- rmsnorm_kernel<<<seqLen, threads>>>(x, w, dim, eps);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
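- // Usage sketch (d_x / d_w are hypothetical device buffers; eps is an example
- // value): normalizes each of seqLen rows of a [seqLen x dim] matrix in place.
- //   float* d_x; // [seqLen * dim] activations (assumed already on device)
- //   float* d_w; // [dim] RMSNorm weights (assumed already on device)
- //   cuda_rmsnorm_f32(d_x, d_w, seqLen, dim, 1e-5f);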
- int cuda_cast_f32_to_f16(const float* src, unsigned short* dst, int n) {
- int threads = 256;
- int blocks = (n + threads - 1) / threads;
- cast_f32_to_f16_kernel<<<blocks, threads>>>(src, reinterpret_cast<__half*>(dst), n);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- int cuda_paged_attention_f32_f16kv(
- const float* Q,
- const unsigned short* const* KBlocks,
- const unsigned short* const* VBlocks,
- float* out,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale, int startPos
- ) {
- // Split-KV Flash Decoding for long contexts.
- const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
- const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
- const int qhCount = seqLen * numHeads;
- const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
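- // Intent of the heuristic: split only when the KV range spans multiple
- // chunks and the (query, head) grid is too small to fill the GPU by itself;
- // with many query-head pairs the extra reduce pass is pure overhead.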
- if (!useSplit) {
- dim3 blocks(seqLen, numHeads);
- int threads = 32;
- paged_attention_kernel_f16kv<<<blocks, threads>>>(
- Q, KBlocks, VBlocks, out,
- seqLen, kvLen, numHeads, numKVHeads, headDim,
- blockSize, scale, startPos
- );
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
- const size_t totalFloats = splitCount * (size_t)(headDim + 2);
- float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
- if (buf == nullptr) {
- return 1;
- }
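- // Scratch layout within the single allocation: [splitCount] running maxima,
- // then [splitCount] normalizers, then [splitCount * headDim] unnormalized
- // partial outputs.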
- float* partialMax = buf;
- float* partialSum = partialMax + splitCount;
- float* partialOut = partialSum + splitCount;
- dim3 blocks1(seqLen, numHeads, numSplits);
- paged_attention_split_kv_kernel<__half><<<blocks1, 32>>>(
- Q,
- reinterpret_cast<const __half* const*>(KBlocks),
- reinterpret_cast<const __half* const*>(VBlocks),
- partialMax, partialSum, partialOut,
- seqLen, kvLen, numHeads, numKVHeads, headDim,
- blockSize,
- scale, startPos,
- numSplits, kPagedAttentionSplitSize
- );
- CHECK_CUDA(cudaGetLastError());
- dim3 blocks2(seqLen, numHeads);
- paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
- partialMax, partialSum, partialOut,
- out,
- seqLen, numHeads, headDim, numSplits
- );
- CHECK_CUDA(cudaGetLastError());
- cuda_free(buf);
- return 0;
- }
- int cuda_paged_attention_rope_f32_f16kv(
- const float* Q,
- const unsigned short* const* KBlocks,
- const unsigned short* const* VBlocks,
- float* out,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale, int startPos,
- float theta
- ) {
- if (ensure_rope_step_table(headDim, theta) != 0) {
- return 1;
- }
- // Split-KV Flash Decoding for long contexts.
- const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
- const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
- const int qhCount = seqLen * numHeads;
- const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
- if (!useSplit) {
- dim3 blocks(seqLen, numHeads);
- int threads = 32;
- paged_attention_kernel_f16kv_rope<<<blocks, threads>>>(
- Q, KBlocks, VBlocks, out,
- seqLen, kvLen, numHeads, numKVHeads, headDim,
- blockSize, scale, startPos
- );
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
- const size_t totalFloats = splitCount * (size_t)(headDim + 2);
- float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
- if (buf == nullptr) {
- return 1;
- }
- float* partialMax = buf;
- float* partialSum = partialMax + splitCount;
- float* partialOut = partialSum + splitCount;
- dim3 blocks1(seqLen, numHeads, numSplits);
- paged_attention_split_kv_kernel<__half, true><<<blocks1, 32>>>(
- Q,
- reinterpret_cast<const __half* const*>(KBlocks),
- reinterpret_cast<const __half* const*>(VBlocks),
- partialMax, partialSum, partialOut,
- seqLen, kvLen, numHeads, numKVHeads, headDim,
- blockSize,
- scale, startPos,
- numSplits, kPagedAttentionSplitSize
- );
- CHECK_CUDA(cudaGetLastError());
- dim3 blocks2(seqLen, numHeads);
- paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
- partialMax, partialSum, partialOut,
- out,
- seqLen, numHeads, headDim, numSplits
- );
- CHECK_CUDA(cudaGetLastError());
- cuda_free(buf);
- return 0;
- }
- int cuda_paged_attention_batch_f32_f16kv(
- const float* Q,
- const unsigned short* const* KBlocksFlat,
- const unsigned short* const* VBlocksFlat,
- const int* blockOffsets,
- const int* kvLens,
- const int* queryPos,
- float* out,
- int numTokens,
- int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale,
- int maxKvLen
- ) {
- const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
- const int qhCount = numTokens * numHeads;
- const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
- if (!useSplit) {
- dim3 blocks(numTokens, numHeads);
- int threads = 32;
- paged_attention_batch_kernel_f16kv<<<blocks, threads>>>(
- Q, KBlocksFlat, VBlocksFlat,
- blockOffsets, kvLens, queryPos,
- out,
- numTokens,
- numHeads, numKVHeads, headDim,
- blockSize,
- scale
- );
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
- const size_t totalFloats = splitCount * (size_t)(headDim + 2);
- float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
- if (buf == nullptr) {
- return 1;
- }
- float* partialMax = buf;
- float* partialSum = partialMax + splitCount;
- float* partialOut = partialSum + splitCount;
- dim3 blocks1(numTokens, numHeads, numSplits);
- paged_attention_split_kv_batch_kernel<__half><<<blocks1, 32>>>(
- Q,
- reinterpret_cast<const __half* const*>(KBlocksFlat),
- reinterpret_cast<const __half* const*>(VBlocksFlat),
- blockOffsets, kvLens, queryPos,
- partialMax, partialSum, partialOut,
- numTokens,
- numHeads, numKVHeads, headDim,
- blockSize,
- scale,
- numSplits, kPagedAttentionSplitSize
- );
- CHECK_CUDA(cudaGetLastError());
- dim3 blocks2(numTokens, numHeads);
- paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
- partialMax, partialSum, partialOut,
- out,
- numTokens, numHeads, headDim, numSplits
- );
- CHECK_CUDA(cudaGetLastError());
- cuda_free(buf);
- return 0;
- }
- int cuda_paged_attention_rope_batch_f32_f16kv(
- const float* Q,
- const unsigned short* const* KBlocksFlat,
- const unsigned short* const* VBlocksFlat,
- const int* blockOffsets,
- const int* kvLens,
- const int* queryPos,
- float* out,
- int numTokens,
- int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale,
- int maxKvLen,
- float theta
- ) {
- if (ensure_rope_step_table(headDim, theta) != 0) {
- return 1;
- }
- const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
- const int qhCount = numTokens * numHeads;
- const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
- if (!useSplit) {
- dim3 blocks(numTokens, numHeads);
- int threads = 32;
- paged_attention_batch_kernel_f16kv_rope<<<blocks, threads>>>(
- Q, KBlocksFlat, VBlocksFlat,
- blockOffsets, kvLens, queryPos,
- out,
- numTokens,
- numHeads, numKVHeads, headDim,
- blockSize,
- scale
- );
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
- const size_t totalFloats = splitCount * (size_t)(headDim + 2);
- float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
- if (buf == nullptr) {
- return 1;
- }
- float* partialMax = buf;
- float* partialSum = partialMax + splitCount;
- float* partialOut = partialSum + splitCount;
- dim3 blocks1(numTokens, numHeads, numSplits);
- paged_attention_split_kv_batch_kernel<__half, true><<<blocks1, 32>>>(
- Q,
- reinterpret_cast<const __half* const*>(KBlocksFlat),
- reinterpret_cast<const __half* const*>(VBlocksFlat),
- blockOffsets, kvLens, queryPos,
- partialMax, partialSum, partialOut,
- numTokens,
- numHeads, numKVHeads, headDim,
- blockSize,
- scale,
- numSplits, kPagedAttentionSplitSize
- );
- CHECK_CUDA(cudaGetLastError());
- dim3 blocks2(numTokens, numHeads);
- paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
- partialMax, partialSum, partialOut,
- out,
- numTokens, numHeads, headDim, numSplits
- );
- CHECK_CUDA(cudaGetLastError());
- cuda_free(buf);
- return 0;
- }
- // Forward declaration: the f32 batch kernel is defined after its launcher
- // below. (The f16 variant is already defined above.)
- __global__ void paged_attention_batch_kernel(
- const float* Q,
- const float* const* KBlocksFlat,
- const float* const* VBlocksFlat,
- const int* blockOffsets,
- const int* kvLens,
- const int* queryPos,
- float* out,
- int numTokens,
- int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale
- );
- int cuda_paged_attention_batch_f32(
- const float* Q,
- const float* const* KBlocksFlat,
- const float* const* VBlocksFlat,
- const int* blockOffsets,
- const int* kvLens,
- const int* queryPos,
- float* out,
- int numTokens,
- int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale,
- int maxKvLen
- ) {
- const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
- const int qhCount = numTokens * numHeads;
- const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
- if (!useSplit) {
- dim3 blocks(numTokens, numHeads);
- int threads = 256;
- paged_attention_batch_kernel<<<blocks, threads>>>(
- Q, KBlocksFlat, VBlocksFlat,
- blockOffsets, kvLens, queryPos,
- out,
- numTokens,
- numHeads, numKVHeads, headDim,
- blockSize,
- scale
- );
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
- const size_t totalFloats = splitCount * (size_t)(headDim + 2);
- float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
- if (buf == nullptr) {
- return 1;
- }
- float* partialMax = buf;
- float* partialSum = partialMax + splitCount;
- float* partialOut = partialSum + splitCount;
- dim3 blocks1(numTokens, numHeads, numSplits);
- paged_attention_split_kv_batch_kernel<float><<<blocks1, 32>>>(
- Q,
- KBlocksFlat,
- VBlocksFlat,
- blockOffsets, kvLens, queryPos,
- partialMax, partialSum, partialOut,
- numTokens,
- numHeads, numKVHeads, headDim,
- blockSize,
- scale,
- numSplits, kPagedAttentionSplitSize
- );
- CHECK_CUDA(cudaGetLastError());
- dim3 blocks2(numTokens, numHeads);
- paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
- partialMax, partialSum, partialOut,
- out,
- numTokens, numHeads, headDim, numSplits
- );
- CHECK_CUDA(cudaGetLastError());
- cuda_free(buf);
- return 0;
- }
- // RoPE kernel
- __global__ void rope_kernel(float* x, const int* positions, int totalDim, int headDim, float theta) {
- int seq = blockIdx.x;
- int pos = positions[seq];
- float* rowData = x + seq * totalDim;
- int halfDim = headDim / 2;
-
- // Each thread handles one (j, j+halfDim) pair across all heads.
- // Compute sin/cos once per j and reuse across heads.
- for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
- const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
- const float freq = pos * invFreq;
- float sinF, cosF;
- sincosf(freq, &sinF, &cosF);
- for (int headStart = 0; headStart < totalDim; headStart += headDim) {
- const int idx0 = headStart + j;
- const int idx1 = idx0 + halfDim;
- const float v0 = rowData[idx0];
- const float v1 = rowData[idx1];
- rowData[idx0] = v0 * cosF - v1 * sinF;
- rowData[idx1] = v1 * cosF + v0 * sinF;
- }
- }
- }
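- // Pairing convention: dimension j rotates with j + headDim/2 (split-halves
- // layout), matching the fused attention kernels above rather than the
- // interleaved (2j, 2j+1) RoPE variant.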
- int cuda_rope_f32(float* x, const int* positions, int seqLen, int numHeads, int headDim, float theta) {
- int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
- rope_kernel<<<seqLen, threads>>>(x, positions, numHeads * headDim, headDim, theta);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- // Optimized RoPE kernel for single position (scalar pos)
- __global__ void rope_kernel_single(float* x, int pos, int totalDim, int headDim, float theta) {
- // seq is always 0 for single token
- float* rowData = x; // x + 0 * totalDim
- int halfDim = headDim / 2;
-
- // Each thread handles one (j, j+halfDim) pair across all heads.
- // Compute sin/cos once per j and reuse across heads.
- for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
- const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
- const float freq = pos * invFreq;
- float sinF, cosF;
- sincosf(freq, &sinF, &cosF);
- for (int headStart = 0; headStart < totalDim; headStart += headDim) {
- const int idx0 = headStart + j;
- const int idx1 = idx0 + halfDim;
- const float v0 = rowData[idx0];
- const float v1 = rowData[idx1];
- rowData[idx0] = v0 * cosF - v1 * sinF;
- rowData[idx1] = v1 * cosF + v0 * sinF;
- }
- }
- }
- int cuda_rope_f32_single(float* x, int pos, int numHeads, int headDim, float theta) {
- int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
- rope_kernel_single<<<1, threads>>>(x, pos, numHeads * headDim, headDim, theta);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- // Softmax kernel: one block per row. The tree reductions below assume
- // blockDim.x is a power of two (<= 256), matching the launcher.
- __global__ void softmax_kernel(float* x, int cols) {
- int row = blockIdx.x;
- float* rowData = x + row * cols;
-
- __shared__ float smax[256];
- __shared__ float ssum[256];
-
- // Find max
- float threadMax = -INFINITY;
- for (int i = threadIdx.x; i < cols; i += blockDim.x) {
- threadMax = fmaxf(threadMax, rowData[i]);
- }
- smax[threadIdx.x] = threadMax;
- __syncthreads();
-
- for (int s = blockDim.x / 2; s > 0; s >>= 1) {
- if (threadIdx.x < s) {
- smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
- }
- __syncthreads();
- }
- float maxVal = smax[0];
-
- // Compute exp and sum
- float threadSum = 0.0f;
- for (int i = threadIdx.x; i < cols; i += blockDim.x) {
- float val = expf(rowData[i] - maxVal);
- rowData[i] = val;
- threadSum += val;
- }
- ssum[threadIdx.x] = threadSum;
- __syncthreads();
-
- for (int s = blockDim.x / 2; s > 0; s >>= 1) {
- if (threadIdx.x < s) {
- ssum[threadIdx.x] += ssum[threadIdx.x + s];
- }
- __syncthreads();
- }
- float sum = ssum[0];
-
- // Normalize
- for (int i = threadIdx.x; i < cols; i += blockDim.x) {
- rowData[i] /= sum;
- }
- }
- int cuda_softmax_f32(float* x, int rows, int cols) {
- int threads = 256;
- softmax_kernel<<<rows, threads>>>(x, cols);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- // ============================================================
- // Top-K Logits Selection (for sampling without full D2H)
- // ============================================================
- #define TOPK_THREADS 256
- #define TOPK_LOCAL 8
- #define TOPK_SEGMENT (TOPK_THREADS * TOPK_LOCAL) // 2048
- #define TOPK_MAX_K 64
- static __device__ __forceinline__ float apply_rep_penalty(float score, bool hit, float penalty) {
- if (!hit) return score;
- if (score > 0.0f) return score / penalty;
- return score * penalty;
- }
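- // Example: with penalty 1.3, a hit at logit 2.6 becomes 2.0 and a hit at
- // logit -1.0 becomes -1.3, pushing repeated tokens toward lower probability.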
- __global__ void topk_logits_kernel(
- const float* logits, int vocab,
- const int* rep_ids, int rep_count, float rep_penalty,
- int k,
- int* out_ids, float* out_scores
- ) {
- const int blockStart = blockIdx.x * TOPK_SEGMENT;
- const int tid = threadIdx.x;
- if (tid >= TOPK_THREADS) return;
- float localScores[TOPK_LOCAL];
- int localIds[TOPK_LOCAL];
- #pragma unroll
- for (int i = 0; i < TOPK_LOCAL; i++) {
- localScores[i] = -INFINITY;
- localIds[i] = -1;
- }
- // Each thread handles up to TOPK_LOCAL elements of the block's segment,
- // strided by TOPK_THREADS so global loads stay coalesced.
- #pragma unroll
- for (int j = 0; j < TOPK_LOCAL; j++) {
- int idx = blockStart + tid + j * TOPK_THREADS;
- if (idx >= vocab) break;
- float score = logits[idx];
- bool hit = false;
- // rep_count is small (<=64). Linear scan is fine.
- for (int r = 0; r < rep_count; r++) {
- if (rep_ids[r] == idx) {
- hit = true;
- break;
- }
- }
- score = apply_rep_penalty(score, hit, rep_penalty);
- // Insert into local top list (descending)
- int pos = TOPK_LOCAL;
- for (int t = 0; t < TOPK_LOCAL; t++) {
- if (score > localScores[t]) { pos = t; break; }
- }
- if (pos < TOPK_LOCAL) {
- for (int t = TOPK_LOCAL - 1; t > pos; t--) {
- localScores[t] = localScores[t-1];
- localIds[t] = localIds[t-1];
- }
- localScores[pos] = score;
- localIds[pos] = idx;
- }
- }
- __shared__ float shScores[TOPK_SEGMENT];
- __shared__ int shIds[TOPK_SEGMENT];
- // Write local results to shared candidate pool.
- #pragma unroll
- for (int j = 0; j < TOPK_LOCAL; j++) {
- int out = tid * TOPK_LOCAL + j;
- shScores[out] = localScores[j];
- shIds[out] = localIds[j];
- }
- __syncthreads();
- if (tid == 0) {
- // Block-level exact top-k from TOPK_SEGMENT candidates.
- if (k > TOPK_MAX_K) k = TOPK_MAX_K;
- float bestScores[TOPK_MAX_K];
- int bestIds[TOPK_MAX_K];
- for (int i = 0; i < k; i++) {
- bestScores[i] = -INFINITY;
- bestIds[i] = -1;
- }
- for (int i = 0; i < TOPK_SEGMENT; i++) {
- float score = shScores[i];
- int id = shIds[i];
- if (id < 0) continue;
- // Insert into best (descending)
- if (score <= bestScores[k-1]) continue;
- int pos = k;
- for (int t = 0; t < k; t++) {
- if (score > bestScores[t]) { pos = t; break; }
- }
- if (pos < k) {
- for (int t = k - 1; t > pos; t--) {
- bestScores[t] = bestScores[t-1];
- bestIds[t] = bestIds[t-1];
- }
- bestScores[pos] = score;
- bestIds[pos] = id;
- }
- }
- int base = blockIdx.x * k;
- for (int i = 0; i < k; i++) {
- out_ids[base + i] = bestIds[i];
- out_scores[base + i] = bestScores[i];
- }
- }
- }
- int cuda_topk_logits_f32(
- const float* logits, int vocab,
- const int* rep_ids, int rep_count, float rep_penalty,
- int k,
- int* out_ids, float* out_scores
- ) {
- if (k <= 0) return 0;
- if (k > TOPK_MAX_K) k = TOPK_MAX_K;
- int blocks = (vocab + TOPK_SEGMENT - 1) / TOPK_SEGMENT;
- dim3 grid(blocks);
- dim3 block(TOPK_THREADS);
- topk_logits_kernel<<<grid, block>>>(logits, vocab, rep_ids, rep_count, rep_penalty, k, out_ids, out_scores);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
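- // Host-side follow-up (sketch; the steps are assumptions about the caller,
- // not part of this file): each block emits an exact top-k of its segment,
- // so the caller still merges blocks * k candidates:
- //   int blocks = (vocab + TOPK_SEGMENT - 1) / TOPK_SEGMENT;
- //   // cudaMemcpy out_ids/out_scores back (blocks * k entries each), then
- //   // partially sort the candidate pool by score, descending, keeping k.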
- // ============================================================
- // Attention Kernel (naive reference)
- // Computes: softmax(Q @ K.T * scale + causal_mask) @ V
- // Requires kvLen * sizeof(float) of dynamic shared memory for the score
- // row; tree reductions assume blockDim.x is a power of two (<= 256).
- // ============================================================
- __global__ void attention_kernel(
- const float* Q, const float* K, const float* V, float* out,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- float scale, int startPos
- ) {
- // Each block handles one (seq, head) pair
- int seq = blockIdx.x;
- int head = blockIdx.y;
-
- int kvHead = head / (numHeads / numKVHeads); // GQA support
-
- const float* q = Q + seq * numHeads * headDim + head * headDim;
- float* o = out + seq * numHeads * headDim + head * headDim;
-
- // Shared memory for attention scores
- extern __shared__ float shared[];
- float* scores = shared; // [kvLen]
-
- // Compute Q @ K.T for this head
- for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
- const float* k = K + kv * numKVHeads * headDim + kvHead * headDim;
-
- float dot = 0.0f;
- for (int d = 0; d < headDim; d++) {
- dot += q[d] * k[d];
- }
-
- // Apply causal mask
- int queryPos = startPos + seq;
- int keyPos = kv;
- if (keyPos > queryPos) {
- dot = -INFINITY;
- }
-
- scores[kv] = dot * scale;
- }
- __syncthreads();
-
- // Softmax over scores
- // Find max
- float maxVal = -INFINITY;
- for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
- maxVal = fmaxf(maxVal, scores[kv]);
- }
-
- // Reduce max across threads
- __shared__ float smax[256];
- smax[threadIdx.x] = maxVal;
- __syncthreads();
- for (int s = blockDim.x / 2; s > 0; s >>= 1) {
- if (threadIdx.x < s) {
- smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
- }
- __syncthreads();
- }
- maxVal = smax[0];
-
- // Exp and sum
- float sum = 0.0f;
- for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
- float val = expf(scores[kv] - maxVal);
- scores[kv] = val;
- sum += val;
- }
-
- __shared__ float ssum[256];
- ssum[threadIdx.x] = sum;
- __syncthreads();
- for (int s = blockDim.x / 2; s > 0; s >>= 1) {
- if (threadIdx.x < s) {
- ssum[threadIdx.x] += ssum[threadIdx.x + s];
- }
- __syncthreads();
- }
- sum = ssum[0];
-
- // Normalize
- for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
- scores[kv] /= sum;
- }
- __syncthreads();
-
- // Compute weighted sum of V
- for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
- float val = 0.0f;
- for (int kv = 0; kv < kvLen; kv++) {
- const float* v = V + kv * numKVHeads * headDim + kvHead * headDim;
- val += scores[kv] * v[d];
- }
- o[d] = val;
- }
- }
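- // Naive paged attention over block tables (f32 KV): three passes over the
- // KV range (running max, exp-sum, weighted V), recomputing Q.K each pass
- // instead of materializing a score row in shared memory.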
- __global__ void paged_attention_kernel(
- const float* Q,
- const float* const* KBlocks,
- const float* const* VBlocks,
- float* out,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale, int startPos
- ) {
- int seq = blockIdx.x;
- int head = blockIdx.y;
- int kvHead = head / (numHeads / numKVHeads);
- const float* q = Q + seq * numHeads * headDim + head * headDim;
- float* o = out + seq * numHeads * headDim + head * headDim;
- const int kvStride = numKVHeads * headDim;
- int queryPos = startPos + seq;
- // Pass 1: max
- float localMax = -INFINITY;
- for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
- int keyPos = kv;
- if (keyPos > queryPos) {
- continue;
- }
- const int blockIdxKV = kv / blockSize;
- const int blockOff = kv % blockSize;
- const float* kBlock = KBlocks[blockIdxKV];
- const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
- float dot = 0.0f;
- for (int d = 0; d < headDim; d++) {
- dot += q[d] * k[d];
- }
- float score = dot * scale;
- localMax = fmaxf(localMax, score);
- }
- __shared__ float smax[256];
- smax[threadIdx.x] = localMax;
- __syncthreads();
- for (int s = blockDim.x / 2; s > 0; s >>= 1) {
- if (threadIdx.x < s) {
- smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
- }
- __syncthreads();
- }
- float maxVal = smax[0];
- // Pass 2: sum exp
- float localSum = 0.0f;
- for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
- int keyPos = kv;
- if (keyPos > queryPos) {
- continue;
- }
- const int blockIdxKV = kv / blockSize;
- const int blockOff = kv % blockSize;
- const float* kBlock = KBlocks[blockIdxKV];
- const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
- float dot = 0.0f;
- for (int d = 0; d < headDim; d++) {
- dot += q[d] * k[d];
- }
- float score = dot * scale;
- localSum += expf(score - maxVal);
- }
- __shared__ float ssum[256];
- ssum[threadIdx.x] = localSum;
- __syncthreads();
- for (int s = blockDim.x / 2; s > 0; s >>= 1) {
- if (threadIdx.x < s) {
- ssum[threadIdx.x] += ssum[threadIdx.x + s];
- }
- __syncthreads();
- }
- float sumVal = ssum[0];
- float invSum = (sumVal > 0.0f) ? (1.0f / sumVal) : 0.0f;
- // Pass 3: output accumulation
- for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
- float outVal = 0.0f;
- for (int kv = 0; kv < kvLen; kv++) {
- int keyPos = kv;
- if (keyPos > queryPos) {
- break;
- }
- const int blockIdxKV = kv / blockSize;
- const int blockOff = kv % blockSize;
- const float* kBlock = KBlocks[blockIdxKV];
- const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
- float dot = 0.0f;
- for (int kd = 0; kd < headDim; kd++) {
- dot += q[kd] * k[kd];
- }
- float score = dot * scale;
- float w = expf(score - maxVal) * invSum;
- const float* vBlock = VBlocks[blockIdxKV];
- const float* v = vBlock + blockOff * kvStride + kvHead * headDim;
- outVal += w * v[d];
- }
- o[d] = outVal;
- }
- }
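- // The paged kernels index K and V through a device-resident table of block base
- // pointers (KBlocks[kv / blockSize]). One way such a table can be built, as a
- // sketch: gather the per-block device pointers on the host and upload them as a
- // pointer array. upload_block_table is a hypothetical helper (plain cudaMalloc /
- // cudaMemcpy here, not this file's cuda_malloc), and assumes <vector> is included.
- static const float** upload_block_table(const std::vector<const float*>& hostPtrs) {
- const float** d = nullptr;
- cudaMalloc((void**)&d, hostPtrs.size() * sizeof(const float*));
- cudaMemcpy(d, hostPtrs.data(), hostPtrs.size() * sizeof(const float*), cudaMemcpyHostToDevice);
- return d;
- }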
- __global__ void paged_attention_batch_kernel(
- const float* Q,
- const float* const* KBlocksFlat,
- const float* const* VBlocksFlat,
- const int* blockOffsets,
- const int* kvLens,
- const int* queryPos,
- float* out,
- int numTokens,
- int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale
- ) {
- int tok = blockIdx.x;
- int head = blockIdx.y;
- if (tok >= numTokens) {
- return;
- }
- int kvHead = head / (numHeads / numKVHeads);
- const float* q = Q + tok * numHeads * headDim + head * headDim;
- float* o = out + tok * numHeads * headDim + head * headDim;
- int kvLen = kvLens[tok];
- int qPos = queryPos[tok];
- int base = blockOffsets[tok];
- const int kvStride = numKVHeads * headDim;
- const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
- float acc = 0.0f; // per-thread output accumulator; only threads with threadIdx.x < headDim use it
- __shared__ float m;
- __shared__ float l;
- __shared__ float alpha;
- __shared__ float beta;
- __shared__ float dotShared;
- if (threadIdx.x == 0) {
- m = -INFINITY;
- l = 0.0f;
- }
- __syncthreads();
- for (int kv = 0; kv < effectiveLen; kv++) {
- const int bidx = kv / blockSize;
- const int boff = kv % blockSize;
- const float* kBlock = KBlocksFlat[base + bidx];
- const float* k = kBlock + boff * kvStride + kvHead * headDim;
- float partial = 0.0f;
- for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
- partial = fmaf(q[d], k[d], partial);
- }
- // Block-wide reduction of the partial dot products: intra-warp shuffles first, then across warps
- for (int offset = 16; offset > 0; offset >>= 1) {
- partial += __shfl_down_sync(0xffffffff, partial, offset);
- }
- __shared__ float warpSum[8]; // one partial per warp; assumes blockDim.x <= 256
- int lane = threadIdx.x & 31;
- int warp = threadIdx.x >> 5;
- if (lane == 0) {
- warpSum[warp] = partial;
- }
- __syncthreads();
- if (warp == 0) {
- float v = (lane < 8) ? warpSum[lane] : 0.0f;
- for (int offset = 16; offset > 0; offset >>= 1) {
- v += __shfl_down_sync(0xffffffff, v, offset);
- }
- if (lane == 0) {
- dotShared = v;
- }
- }
- __syncthreads();
- float score = dotShared * scale;
- if (threadIdx.x == 0) {
- float newM = fmaxf(m, score);
- float a = expf(m - newM);
- float b = expf(score - newM);
- m = newM;
- l = l * a + b;
- alpha = a;
- beta = b;
- }
- __syncthreads();
- if (threadIdx.x < headDim) {
- const float* vBlock = VBlocksFlat[base + bidx];
- const float* v = vBlock + boff * kvStride + kvHead * headDim;
- acc = fmaf(beta, v[threadIdx.x], acc * alpha);
- }
- __syncthreads();
- }
- if (threadIdx.x < headDim) {
- float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
- o[threadIdx.x] = acc * invL;
- }
- }
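- // The batch kernel keeps a running softmax in a single pass over the KV cache using
- // the standard online rescaling: with running max m and running denominator l, each
- // new score s updates
- //   m'   = max(m, s)
- //   l'   = l * exp(m - m') + exp(s - m')
- //   acc' = acc * exp(m - m') + exp(s - m') * v
- // which is exactly the alpha/beta update above; the final output is acc / l. This
- // matches the multi-pass softmax result up to floating-point rounding.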
- int cuda_attention_f32(
- const float* Q, const float* K, const float* V, float* out,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- float scale, int startPos
- ) {
- dim3 blocks(seqLen, numHeads);
- int threads = 256;
- size_t sharedMem = kvLen * sizeof(float);
-
- attention_kernel<<<blocks, threads, sharedMem>>>(
- Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
- );
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
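- // Note: the dynamic shared allocation is kvLen * sizeof(float), so under the default
- // 48 KB per-block limit this launcher covers kvLen up to 12288. For longer contexts,
- // one option (a sketch, not wired in here) is to opt the kernel into a larger
- // carve-out on devices that support it:
- // cudaFuncSetAttribute(attention_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kvLen * (int)sizeof(float));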
- int cuda_paged_attention_f32(
- const float* Q,
- const float* const* KBlocks,
- const float* const* VBlocks,
- float* out,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- int blockSize,
- float scale, int startPos
- ) {
- // Split-KV Flash Decoding for long contexts.
- const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
- const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
- const int qhCount = seqLen * numHeads;
- const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
- if (!useSplit) {
- dim3 blocks(seqLen, numHeads);
- int threads = 256;
- paged_attention_kernel<<<blocks, threads>>>(
- Q, KBlocks, VBlocks, out,
- seqLen, kvLen, numHeads, numKVHeads, headDim,
- blockSize, scale, startPos
- );
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
- const size_t totalFloats = splitCount * (size_t)(headDim + 2);
- float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
- if (buf == nullptr) {
- return 1;
- }
- float* partialMax = buf;
- float* partialSum = partialMax + splitCount;
- float* partialOut = partialSum + splitCount;
- dim3 blocks1(seqLen, numHeads, numSplits);
- paged_attention_split_kv_kernel<float><<<blocks1, 32>>>(
- Q, KBlocks, VBlocks,
- partialMax, partialSum, partialOut,
- seqLen, kvLen, numHeads, numKVHeads, headDim,
- blockSize,
- scale, startPos,
- numSplits, kPagedAttentionSplitSize
- );
- CHECK_CUDA(cudaGetLastError());
- dim3 blocks2(seqLen, numHeads);
- paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
- partialMax, partialSum, partialOut,
- out,
- seqLen, numHeads, headDim, numSplits
- );
- CHECK_CUDA(cudaGetLastError());
- cuda_free(buf);
- return 0;
- }
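- // Scratch layout for the split path, for reference: per (query, head) pair and per
- // split there is one running max, one running sum, and a headDim-wide partial
- // output, packed as [max | sum | out] = qhCount * numSplits * (1 + 1 + headDim)
- // floats; the reduce kernel then merges the splits for each pair.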
- int cuda_attention_f32_timed(
- const float* Q, const float* K, const float* V, float* out,
- int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
- float scale, int startPos, float* ms
- ) {
- cudaEvent_t evStart;
- cudaEvent_t evStop;
- CHECK_CUDA(cudaEventCreate(&evStart));
- CHECK_CUDA(cudaEventCreate(&evStop));
- dim3 blocks(seqLen, numHeads);
- int threads = 256;
- size_t sharedMem = kvLen * sizeof(float);
- CHECK_CUDA(cudaEventRecord(evStart));
- attention_kernel<<<blocks, threads, sharedMem>>>(
- Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
- );
- CHECK_CUDA(cudaEventRecord(evStop));
- CHECK_CUDA(cudaEventSynchronize(evStop));
- float elapsed = 0.0f;
- CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
- if (ms != NULL) {
- *ms = elapsed;
- }
- CHECK_CUDA(cudaEventDestroy(evStart));
- CHECK_CUDA(cudaEventDestroy(evStop));
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
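- // Typical use of the timed variant (a sketch; the device pointer names dQ/dK/dV/dOut
- // are placeholders):
- // float ms = 0.0f;
- // cuda_attention_f32_timed(dQ, dK, dV, dOut, seqLen, kvLen, numHeads, numKVHeads, headDim, 1.0f / sqrtf((float)headDim), startPos, &ms);
- // printf("attention: %.3f ms\n", ms);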
- // Debug helper
- int cuda_print_struct_sizes() {
- printf("GPU Struct Sizes:\n");
- printf("BlockQ2_K: %lu\n", sizeof(BlockQ2_K));
- printf("BlockQ3_K: %lu\n", sizeof(BlockQ3_K));
- printf("BlockQ4_K: %lu\n", sizeof(BlockQ4_K));
- printf("BlockQ6_K: %lu\n", sizeof(BlockQ6_K));
- printf("BlockQ8_K: %lu\n", sizeof(BlockQ8_K));
- return 0;
- }