#include "cuda_common.cuh" #include #include #include namespace { constexpr int kPagedAttentionSplitSize = 1024; constexpr int kPagedAttentionSplitQHThreshold = 4096; // queryCount*numHeads threshold } // namespace // ============================================================ // Fused RoPE helpers (constant step table + fast complex pow) // ============================================================ // Stores cos/sin of per-dimension "step" angle (invFreq) for RoPE. // Only indices [0, headDim/2) are used. Max supported headDim is 256. __device__ __constant__ float rope_cos_step_const[128]; __device__ __constant__ float rope_sin_step_const[128]; static int g_rope_step_inited[32] = {0}; static int g_rope_step_head_dim[32] = {0}; static uint32_t g_rope_step_theta_bits[32] = {0}; static int ensure_rope_step_table(int headDim, float theta) { if (headDim <= 0 || headDim > 256 || (headDim & 1) != 0) { return 1; } int dev = 0; CHECK_CUDA(cudaGetDevice(&dev)); if (dev < 0 || dev >= 32) { dev = 0; } uint32_t thetaBits = 0; memcpy(&thetaBits, &theta, sizeof(thetaBits)); if (g_rope_step_inited[dev] && g_rope_step_head_dim[dev] == headDim && g_rope_step_theta_bits[dev] == thetaBits) { return 0; } float cosStep[128]; float sinStep[128]; const int halfDim = headDim / 2; for (int j = 0; j < 128; j++) { if (j < halfDim) { // invFreq = theta^(-2j/headDim) const double exp = -2.0 * (double)j / (double)headDim; const double invFreq = pow((double)theta, exp); cosStep[j] = (float)cos(invFreq); sinStep[j] = (float)sin(invFreq); } else { cosStep[j] = 1.0f; sinStep[j] = 0.0f; } } CHECK_CUDA(cudaMemcpyToSymbol(rope_cos_step_const, cosStep, sizeof(cosStep), 0, cudaMemcpyHostToDevice)); CHECK_CUDA(cudaMemcpyToSymbol(rope_sin_step_const, sinStep, sizeof(sinStep), 0, cudaMemcpyHostToDevice)); g_rope_step_inited[dev] = 1; g_rope_step_head_dim[dev] = headDim; g_rope_step_theta_bits[dev] = thetaBits; return 0; } __device__ __forceinline__ float2 complex_mul_f2(float2 a, float2 b) { // (a.x + i a.y) * (b.x + i b.y) return make_float2( fmaf(a.x, b.x, -a.y * b.y), fmaf(a.x, b.y, a.y * b.x) ); } __device__ __forceinline__ float2 complex_pow_int(float2 base, int exp) { float2 result = make_float2(1.0f, 0.0f); float2 b = base; int e = exp; while (e > 0) { if (e & 1) { result = complex_mul_f2(result, b); } b = complex_mul_f2(b, b); e >>= 1; } return result; } __device__ __forceinline__ void rope_advance_neg(float& cosv, float& sinv, float cosStep, float sinStep) { // Multiply by exp(-i*step): (cos + i sin) * (cosStep - i sinStep) const float c = cosv; const float s = sinv; cosv = fmaf(c, cosStep, s * sinStep); sinv = fmaf(s, cosStep, -c * sinStep); } // ============================================================ // Neural Network Operations // ============================================================ // RMSNorm kernel: one block per row __global__ void rmsnorm_kernel(float* x, const float* w, int dim, float eps) { int row = blockIdx.x; float* rowData = x + row * dim; float sum = 0.0f; for (int i = threadIdx.x; i < dim; i += blockDim.x) { float v = rowData[i]; sum = fmaf(v, v, sum); } // Warp reduce for (int offset = 16; offset > 0; offset >>= 1) { sum += __shfl_down_sync(0xffffffff, sum, offset); } __shared__ float warpSum[8]; __shared__ float rms; int lane = threadIdx.x & 31; int warp = threadIdx.x >> 5; if (lane == 0) { warpSum[warp] = sum; } __syncthreads(); if (warp == 0) { float v = (lane < 8) ? 
warpSum[lane] : 0.0f; for (int offset = 16; offset > 0; offset >>= 1) { v += __shfl_down_sync(0xffffffff, v, offset); } if (lane == 0) { rms = rsqrtf(v / dim + eps); } } __syncthreads(); for (int i = threadIdx.x; i < dim; i += blockDim.x) { rowData[i] = rowData[i] * rms * w[i]; } } __global__ void paged_attention_batch_kernel_f16kv( const float* Q, const unsigned short* const* KBlocksFlat, const unsigned short* const* VBlocksFlat, const int* blockOffsets, const int* kvLens, const int* queryPos, float* out, int numTokens, int numHeads, int numKVHeads, int headDim, int blockSize, float scale ) { const int tok = blockIdx.x; const int head = blockIdx.y; const int lane = threadIdx.x & 31; if (tok >= numTokens) { return; } // One warp per (tok, head) if (threadIdx.x >= 32) { return; } const int kvHead = head / (numHeads / numKVHeads); const float* q = Q + tok * numHeads * headDim + head * headDim; float* o = out + tok * numHeads * headDim + head * headDim; const int kvLen = kvLens[tok]; const int qPos = queryPos[tok]; const int base = blockOffsets[tok]; const int kvStride = numKVHeads * headDim; const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1); // Cache Q in registers (per lane) to avoid reloading it for every KV token. float qreg[8]; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; qreg[i] = (d < headDim) ? q[d] : 0.0f; } // Support headDim up to 256 (<= 8 values per lane) float acc[8]; #pragma unroll for (int i = 0; i < 8; i++) { acc[i] = 0.0f; } float m = -INFINITY; float l = 0.0f; for (int kv = 0; kv < effectiveLen; kv++) { const int bidx = kv / blockSize; const int boff = kv % blockSize; const __half* kBlock = reinterpret_cast(KBlocksFlat[base + bidx]); const __half* k = kBlock + boff * kvStride + kvHead * headDim; float dot = 0.0f; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { dot = fmaf(qreg[i], __half2float(k[d]), dot); } } for (int offset = 16; offset > 0; offset >>= 1) { dot += __shfl_down_sync(0xffffffff, dot, offset); } dot = __shfl_sync(0xffffffff, dot, 0); const float score = dot * scale; float alpha = 1.0f; float beta = 0.0f; if (lane == 0) { const float newM = fmaxf(m, score); alpha = __expf(m - newM); beta = __expf(score - newM); m = newM; l = l * alpha + beta; } alpha = __shfl_sync(0xffffffff, alpha, 0); beta = __shfl_sync(0xffffffff, beta, 0); l = __shfl_sync(0xffffffff, l, 0); const __half* vBlock = reinterpret_cast(VBlocksFlat[base + bidx]); const __half* v = vBlock + boff * kvStride + kvHead * headDim; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { acc[i] = acc[i] * alpha + beta * __half2float(v[d]); } } } const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { o[d] = acc[i] * invL; } } } // Fused RoPE + paged attention (batch, f16 KV). Expects un-rotated Q/K. 
__global__ void paged_attention_batch_kernel_f16kv_rope( const float* Q, const unsigned short* const* KBlocksFlat, const unsigned short* const* VBlocksFlat, const int* blockOffsets, const int* kvLens, const int* queryPos, float* out, int numTokens, int numHeads, int numKVHeads, int headDim, int blockSize, float scale ) { const int tok = blockIdx.x; const int head = blockIdx.y; const int lane = threadIdx.x & 31; if (tok >= numTokens) { return; } // One warp per (tok, head) if (threadIdx.x >= 32) { return; } const int kvHead = head / (numHeads / numKVHeads); const float* q = Q + tok * numHeads * headDim + head * headDim; float* o = out + tok * numHeads * headDim + head * headDim; const int kvLen = kvLens[tok]; const int qPos = queryPos[tok]; const int base = blockOffsets[tok]; const int kvStride = numKVHeads * headDim; const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1); const int halfDim = headDim >> 1; // Cache Q pairs + per-dim RoPE phase for delta=(qPos - kv) with kv starting at 0. float q0[4]; float q1[4]; float cosStep[4]; float sinStep[4]; float cosDelta[4]; float sinDelta[4]; int pairCount = 0; for (int j = lane; j < halfDim; j += 32) { q0[pairCount] = q[j]; q1[pairCount] = q[j + halfDim]; cosStep[pairCount] = rope_cos_step_const[j]; sinStep[pairCount] = rope_sin_step_const[j]; const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]); const float2 ph = complex_pow_int(baseStep, qPos); cosDelta[pairCount] = ph.x; sinDelta[pairCount] = ph.y; pairCount++; } // Support headDim up to 256. float acc[8]; #pragma unroll for (int i = 0; i < 8; i++) { acc[i] = 0.0f; } float m = -INFINITY; float l = 0.0f; for (int kv = 0; kv < effectiveLen; kv++) { const int bidx = kv / blockSize; const int boff = kv % blockSize; const __half* kBlock = reinterpret_cast(KBlocksFlat[base + bidx]); const __half* k = kBlock + boff * kvStride + kvHead * headDim; float dot = 0.0f; #pragma unroll for (int pi = 0; pi < 4; pi++) { if (pi >= pairCount) { break; } const int j = lane + 32 * pi; if (j >= halfDim) { continue; } const float k0 = __half2float(k[j]); const float k1 = __half2float(k[j + halfDim]); const float a = fmaf(q0[pi], k0, q1[pi] * k1); // q0*k0 + q1*k1 const float b = fmaf(q0[pi], k1, -q1[pi] * k0); // q0*k1 - q1*k0 dot = fmaf(cosDelta[pi], a, dot); dot = fmaf(sinDelta[pi], b, dot); } for (int offset = 16; offset > 0; offset >>= 1) { dot += __shfl_down_sync(0xffffffff, dot, offset); } dot = __shfl_sync(0xffffffff, dot, 0); // Advance delta -> delta-1 for next kv. #pragma unroll for (int pi = 0; pi < 4; pi++) { if (pi >= pairCount) { break; } rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]); } const float score = dot * scale; float alpha = 1.0f; float beta = 0.0f; if (lane == 0) { const float newM = fmaxf(m, score); alpha = __expf(m - newM); beta = __expf(score - newM); m = newM; l = l * alpha + beta; } alpha = __shfl_sync(0xffffffff, alpha, 0); beta = __shfl_sync(0xffffffff, beta, 0); l = __shfl_sync(0xffffffff, l, 0); const __half* vBlock = reinterpret_cast(VBlocksFlat[base + bidx]); const __half* v = vBlock + boff * kvStride + kvHead * headDim; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { acc[i] = acc[i] * alpha + beta * __half2float(v[d]); } } } const float invL = (l > 0.0f) ? 
(1.0f / l) : 0.0f; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { o[d] = acc[i] * invL; } } } __global__ void cast_f32_to_f16_kernel(const float* src, __half* dst, int n) { int i = blockIdx.x * blockDim.x + threadIdx.x; if (i < n) { dst[i] = __float2half_rn(src[i]); } } __global__ void paged_attention_kernel_f16kv( const float* Q, const unsigned short* const* KBlocks, const unsigned short* const* VBlocks, float* out, int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim, int blockSize, float scale, int startPos ) { const int seq = blockIdx.x; const int head = blockIdx.y; const int lane = threadIdx.x & 31; if (seq >= seqLen) { return; } // One warp per (seq, head) if (threadIdx.x >= 32) { return; } const int kvHead = head / (numHeads / numKVHeads); const float* q = Q + seq * numHeads * headDim + head * headDim; float* o = out + seq * numHeads * headDim + head * headDim; const int kvStride = numKVHeads * headDim; const int queryPos = startPos + seq; const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1); // Cache Q in registers (per lane) to avoid reloading it for every KV token. float qreg[8]; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; qreg[i] = (d < headDim) ? q[d] : 0.0f; } float acc[8]; #pragma unroll for (int i = 0; i < 8; i++) { acc[i] = 0.0f; } float m = -INFINITY; float l = 0.0f; for (int kv = 0; kv < effectiveLen; kv++) { const int blockIdxKV = kv / blockSize; const int blockOff = kv % blockSize; const __half* kBlock = reinterpret_cast(KBlocks[blockIdxKV]); const __half* k = kBlock + blockOff * kvStride + kvHead * headDim; float dot = 0.0f; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { dot = fmaf(qreg[i], __half2float(k[d]), dot); } } for (int offset = 16; offset > 0; offset >>= 1) { dot += __shfl_down_sync(0xffffffff, dot, offset); } dot = __shfl_sync(0xffffffff, dot, 0); const float score = dot * scale; float alpha = 1.0f; float beta = 0.0f; if (lane == 0) { const float newM = fmaxf(m, score); alpha = __expf(m - newM); beta = __expf(score - newM); m = newM; l = l * alpha + beta; } alpha = __shfl_sync(0xffffffff, alpha, 0); beta = __shfl_sync(0xffffffff, beta, 0); l = __shfl_sync(0xffffffff, l, 0); const __half* vBlock = reinterpret_cast(VBlocks[blockIdxKV]); const __half* v = vBlock + blockOff * kvStride + kvHead * headDim; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { acc[i] = acc[i] * alpha + beta * __half2float(v[d]); } } } const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { o[d] = acc[i] * invL; } } } // Fused RoPE + paged attention (single, f16 KV). Expects un-rotated Q/K. 
__global__ void paged_attention_kernel_f16kv_rope( const float* Q, const unsigned short* const* KBlocks, const unsigned short* const* VBlocks, float* out, int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim, int blockSize, float scale, int startPos ) { const int seq = blockIdx.x; const int head = blockIdx.y; const int lane = threadIdx.x & 31; if (seq >= seqLen) { return; } // One warp per (seq, head) if (threadIdx.x >= 32) { return; } const int kvHead = head / (numHeads / numKVHeads); const float* q = Q + seq * numHeads * headDim + head * headDim; float* o = out + seq * numHeads * headDim + head * headDim; const int kvStride = numKVHeads * headDim; const int queryPos = startPos + seq; const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1); const int halfDim = headDim >> 1; float q0[4]; float q1[4]; float cosStep[4]; float sinStep[4]; float cosDelta[4]; float sinDelta[4]; int pairCount = 0; for (int j = lane; j < halfDim; j += 32) { q0[pairCount] = q[j]; q1[pairCount] = q[j + halfDim]; cosStep[pairCount] = rope_cos_step_const[j]; sinStep[pairCount] = rope_sin_step_const[j]; const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]); const float2 ph = complex_pow_int(baseStep, queryPos); cosDelta[pairCount] = ph.x; sinDelta[pairCount] = ph.y; pairCount++; } float acc[8]; #pragma unroll for (int i = 0; i < 8; i++) { acc[i] = 0.0f; } float m = -INFINITY; float l = 0.0f; for (int kv = 0; kv < effectiveLen; kv++) { const int blockIdxKV = kv / blockSize; const int blockOff = kv % blockSize; const __half* kBlock = reinterpret_cast(KBlocks[blockIdxKV]); const __half* k = kBlock + blockOff * kvStride + kvHead * headDim; float dot = 0.0f; #pragma unroll for (int pi = 0; pi < 4; pi++) { if (pi >= pairCount) { break; } const int j = lane + 32 * pi; if (j >= halfDim) { continue; } const float k0 = __half2float(k[j]); const float k1 = __half2float(k[j + halfDim]); const float a = fmaf(q0[pi], k0, q1[pi] * k1); const float b = fmaf(q0[pi], k1, -q1[pi] * k0); dot = fmaf(cosDelta[pi], a, dot); dot = fmaf(sinDelta[pi], b, dot); } for (int offset = 16; offset > 0; offset >>= 1) { dot += __shfl_down_sync(0xffffffff, dot, offset); } dot = __shfl_sync(0xffffffff, dot, 0); // Advance delta -> delta-1 for next kv. #pragma unroll for (int pi = 0; pi < 4; pi++) { if (pi >= pairCount) { break; } rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]); } const float score = dot * scale; float alpha = 1.0f; float beta = 0.0f; if (lane == 0) { const float newM = fmaxf(m, score); alpha = __expf(m - newM); beta = __expf(score - newM); m = newM; l = l * alpha + beta; } alpha = __shfl_sync(0xffffffff, alpha, 0); beta = __shfl_sync(0xffffffff, beta, 0); l = __shfl_sync(0xffffffff, l, 0); const __half* vBlock = reinterpret_cast(VBlocks[blockIdxKV]); const __half* v = vBlock + blockOff * kvStride + kvHead * headDim; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { acc[i] = acc[i] * alpha + beta * __half2float(v[d]); } } } const float invL = (l > 0.0f) ? 
(1.0f / l) : 0.0f;
    #pragma unroll
    for (int i = 0; i < 8; i++) {
        const int d = lane + 32 * i;
        if (d < headDim) {
            o[d] = acc[i] * invL;
        }
    }
}

template <typename T>
__device__ __forceinline__ float load_kv(const T* p, int idx) {
    return p[idx];
}

template <>
__device__ __forceinline__ float load_kv<__half>(const __half* p, int idx) {
    return __half2float(p[idx]);
}

template <typename T, bool kUseRoPE = false>
__global__ void paged_attention_split_kv_kernel(
    const float* Q, const T* const* KBlocks, const T* const* VBlocks,
    float* partialMax, float* partialSum, float* partialOut,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int startPos, int numSplits, int splitSize
) {
    const int seq = blockIdx.x;
    const int head = blockIdx.y;
    const int split = blockIdx.z;
    const int lane = threadIdx.x & 31;
    if (seq >= seqLen) { return; }
    // One warp per (seq, head, split).
    if (threadIdx.x >= 32) { return; }
    const int kvHead = head / (numHeads / numKVHeads);
    const float* q = Q + seq * numHeads * headDim + head * headDim;
    const int kvStride = numKVHeads * headDim;
    const int queryPos = startPos + seq;
    const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
    const int splitStart = split * splitSize;
    const int splitEnd = (splitStart + splitSize < effectiveLen) ? (splitStart + splitSize) : effectiveLen;
    const size_t qh = (size_t)seq * (size_t)numHeads + (size_t)head;
    const size_t splitIdx = qh * (size_t)numSplits + (size_t)split;
    float* outVec = partialOut + splitIdx * (size_t)headDim;
    if (splitStart >= splitEnd) {
        if (lane == 0) {
            partialMax[splitIdx] = -INFINITY;
            partialSum[splitIdx] = 0.0f;
        }
        #pragma unroll
        for (int i = 0; i < 8; i++) {
            const int d = lane + 32 * i;
            if (d < headDim) { outVec[d] = 0.0f; }
        }
        return;
    }
    const int halfDim = headDim >> 1;
    const int ropeExp = queryPos - splitStart;
    float qreg[8];
    float q0[4]; float q1[4];
    float cosStep[4]; float sinStep[4];
    float cosDelta[4]; float sinDelta[4];
    int pairCount = 0;
    if (kUseRoPE) {
        for (int j = lane; j < halfDim; j += 32) {
            q0[pairCount] = q[j];
            q1[pairCount] = q[j + halfDim];
            cosStep[pairCount] = rope_cos_step_const[j];
            sinStep[pairCount] = rope_sin_step_const[j];
            const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
            const float2 ph = complex_pow_int(baseStep, ropeExp);
            cosDelta[pairCount] = ph.x;
            sinDelta[pairCount] = ph.y;
            pairCount++;
        }
    } else {
        #pragma unroll
        for (int i = 0; i < 8; i++) {
            const int d = lane + 32 * i;
            qreg[i] = (d < headDim) ? q[d] : 0.0f;
        }
    }
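    // Streaming (online) softmax over this split's KV range, flash-attention style:
    // keep a running max m, running normalizer l, and an unnormalized accumulator
    // acc. For each new score s:
    //   m'    = max(m, s)
    //   alpha = exp(m - m'),  beta = exp(s - m')
    //   l'    = l * alpha + beta
    //   acc'  = acc * alpha + beta * v
    // After the loop, acc / l equals softmax(scores) @ V for this split, and
    // (m, l) are written out so the reduce kernel can merge splits exactly.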
    float acc[8];
    #pragma unroll
    for (int i = 0; i < 8; i++) { acc[i] = 0.0f; }
    float m = -INFINITY;
    float l = 0.0f;
    for (int kv = splitStart; kv < splitEnd; kv++) {
        const int blockIdxKV = kv / blockSize;
        const int blockOff = kv % blockSize;
        const T* kBlock = KBlocks[blockIdxKV];
        const T* k = kBlock + blockOff * kvStride + kvHead * headDim;
        float dot = 0.0f;
        if (kUseRoPE) {
            #pragma unroll
            for (int pi = 0; pi < 4; pi++) {
                if (pi >= pairCount) { break; }
                const int j = lane + 32 * pi;
                if (j >= halfDim) { continue; }
                const float k0 = load_kv(k, j);
                const float k1 = load_kv(k, j + halfDim);
                const float a = fmaf(q0[pi], k0, q1[pi] * k1);
                const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
                dot = fmaf(cosDelta[pi], a, dot);
                dot = fmaf(sinDelta[pi], b, dot);
            }
        } else {
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                const int d = lane + 32 * i;
                if (d < headDim) { dot = fmaf(qreg[i], load_kv(k, d), dot); }
            }
        }
        for (int offset = 16; offset > 0; offset >>= 1) {
            dot += __shfl_down_sync(0xffffffff, dot, offset);
        }
        dot = __shfl_sync(0xffffffff, dot, 0);
        if (kUseRoPE) {
            #pragma unroll
            for (int pi = 0; pi < 4; pi++) {
                if (pi >= pairCount) { break; }
                rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
            }
        }
        const float score = dot * scale;
        float alpha = 1.0f;
        float beta = 0.0f;
        if (lane == 0) {
            const float newM = fmaxf(m, score);
            alpha = __expf(m - newM);
            beta = __expf(score - newM);
            m = newM;
            l = l * alpha + beta;
        }
        alpha = __shfl_sync(0xffffffff, alpha, 0);
        beta = __shfl_sync(0xffffffff, beta, 0);
        const T* vBlock = VBlocks[blockIdxKV];
        const T* v = vBlock + blockOff * kvStride + kvHead * headDim;
        #pragma unroll
        for (int i = 0; i < 8; i++) {
            const int d = lane + 32 * i;
            if (d < headDim) { acc[i] = acc[i] * alpha + beta * load_kv(v, d); }
        }
    }
    if (lane == 0) {
        partialMax[splitIdx] = m;
        partialSum[splitIdx] = l;
    }
    #pragma unroll
    for (int i = 0; i < 8; i++) {
        const int d = lane + 32 * i;
        if (d < headDim) { outVec[d] = acc[i]; }
    }
}

template <typename T, bool kUseRoPE = false>
__global__ void paged_attention_split_kv_batch_kernel(
    const float* Q, const T* const* KBlocksFlat, const T* const* VBlocksFlat,
    const int* blockOffsets, const int* kvLens, const int* queryPos,
    float* partialMax, float* partialSum, float* partialOut,
    int numTokens, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int numSplits, int splitSize
) {
    const int tok = blockIdx.x;
    const int head = blockIdx.y;
    const int split = blockIdx.z;
    const int lane = threadIdx.x & 31;
    if (tok >= numTokens) { return; }
    // One warp per (tok, head, split).
    if (threadIdx.x >= 32) { return; }
    const int kvHead = head / (numHeads / numKVHeads);
    const float* q = Q + tok * numHeads * headDim + head * headDim;
    const int kvLen = kvLens[tok];
    const int qPos = queryPos[tok];
    const int base = blockOffsets[tok];
    const int kvStride = numKVHeads * headDim;
    const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
    const int splitStart = split * splitSize;
    const int splitEnd = (splitStart + splitSize < effectiveLen) ?
(splitStart + splitSize) : effectiveLen; const size_t qh = (size_t)tok * (size_t)numHeads + (size_t)head; const size_t splitIdx = qh * (size_t)numSplits + (size_t)split; float* outVec = partialOut + splitIdx * (size_t)headDim; if (splitStart >= splitEnd) { if (lane == 0) { partialMax[splitIdx] = -INFINITY; partialSum[splitIdx] = 0.0f; } #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { outVec[d] = 0.0f; } } return; } const int halfDim = headDim >> 1; const int ropeExp = qPos - splitStart; float qreg[8]; float q0[4]; float q1[4]; float cosStep[4]; float sinStep[4]; float cosDelta[4]; float sinDelta[4]; int pairCount = 0; if (kUseRoPE) { for (int j = lane; j < halfDim; j += 32) { q0[pairCount] = q[j]; q1[pairCount] = q[j + halfDim]; cosStep[pairCount] = rope_cos_step_const[j]; sinStep[pairCount] = rope_sin_step_const[j]; const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]); const float2 ph = complex_pow_int(baseStep, ropeExp); cosDelta[pairCount] = ph.x; sinDelta[pairCount] = ph.y; pairCount++; } } else { #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; qreg[i] = (d < headDim) ? q[d] : 0.0f; } } float acc[8]; #pragma unroll for (int i = 0; i < 8; i++) { acc[i] = 0.0f; } float m = -INFINITY; float l = 0.0f; for (int kv = splitStart; kv < splitEnd; kv++) { const int bidx = kv / blockSize; const int boff = kv % blockSize; const T* kBlock = KBlocksFlat[base + bidx]; const T* k = kBlock + boff * kvStride + kvHead * headDim; float dot = 0.0f; if (kUseRoPE) { #pragma unroll for (int pi = 0; pi < 4; pi++) { if (pi >= pairCount) { break; } const int j = lane + 32 * pi; if (j >= halfDim) { continue; } const float k0 = load_kv(k, j); const float k1 = load_kv(k, j + halfDim); const float a = fmaf(q0[pi], k0, q1[pi] * k1); const float b = fmaf(q0[pi], k1, -q1[pi] * k0); dot = fmaf(cosDelta[pi], a, dot); dot = fmaf(sinDelta[pi], b, dot); } } else { #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { dot = fmaf(qreg[i], load_kv(k, d), dot); } } } for (int offset = 16; offset > 0; offset >>= 1) { dot += __shfl_down_sync(0xffffffff, dot, offset); } dot = __shfl_sync(0xffffffff, dot, 0); if (kUseRoPE) { #pragma unroll for (int pi = 0; pi < 4; pi++) { if (pi >= pairCount) { break; } rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]); } } const float score = dot * scale; float alpha = 1.0f; float beta = 0.0f; if (lane == 0) { const float newM = fmaxf(m, score); alpha = __expf(m - newM); beta = __expf(score - newM); m = newM; l = l * alpha + beta; } alpha = __shfl_sync(0xffffffff, alpha, 0); beta = __shfl_sync(0xffffffff, beta, 0); const T* vBlock = VBlocksFlat[base + bidx]; const T* v = vBlock + boff * kvStride + kvHead * headDim; #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { acc[i] = acc[i] * alpha + beta * load_kv(v, d); } } } if (lane == 0) { partialMax[splitIdx] = m; partialSum[splitIdx] = l; } #pragma unroll for (int i = 0; i < 8; i++) { const int d = lane + 32 * i; if (d < headDim) { outVec[d] = acc[i]; } } } __global__ void paged_attention_split_kv_reduce_kernel( const float* partialMax, const float* partialSum, const float* partialOut, float* out, int queryCount, int numHeads, int headDim, int numSplits ) { const int q = blockIdx.x; const int head = blockIdx.y; const int lane = threadIdx.x & 31; if (q >= queryCount) { return; } // One warp per (q, head). 
    if (threadIdx.x >= 32) { return; }
    const size_t qh = (size_t)q * (size_t)numHeads + (size_t)head;
    const size_t base = qh * (size_t)numSplits;
    float gmax = -INFINITY;
    if (lane == 0) {
        for (int s = 0; s < numSplits; s++) {
            gmax = fmaxf(gmax, partialMax[base + (size_t)s]);
        }
    }
    gmax = __shfl_sync(0xffffffff, gmax, 0);
    if (gmax == -INFINITY) {
        const size_t outBase = qh * (size_t)headDim;
        #pragma unroll
        for (int i = 0; i < 8; i++) {
            const int d = lane + 32 * i;
            if (d < headDim) { out[outBase + (size_t)d] = 0.0f; }
        }
        return;
    }
    float sum = 0.0f;
    if (lane == 0) {
        for (int s = 0; s < numSplits; s++) {
            const float l = partialSum[base + (size_t)s];
            sum += l * __expf(partialMax[base + (size_t)s] - gmax);
        }
    }
    sum = __shfl_sync(0xffffffff, sum, 0);
    const float invSum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
    float acc[8];
    #pragma unroll
    for (int i = 0; i < 8; i++) { acc[i] = 0.0f; }
    for (int s = 0; s < numSplits; s++) {
        float scale = 0.0f;
        if (lane == 0) {
            scale = __expf(partialMax[base + (size_t)s] - gmax);
        }
        scale = __shfl_sync(0xffffffff, scale, 0);
        const float* vec = partialOut + (base + (size_t)s) * (size_t)headDim;
        #pragma unroll
        for (int i = 0; i < 8; i++) {
            const int d = lane + 32 * i;
            if (d < headDim) { acc[i] = fmaf(scale, vec[d], acc[i]); }
        }
    }
    const size_t outBase = qh * (size_t)headDim;
    #pragma unroll
    for (int i = 0; i < 8; i++) {
        const int d = lane + 32 * i;
        if (d < headDim) { out[outBase + (size_t)d] = acc[i] * invSum; }
    }
}

int cuda_rmsnorm_f32(float* x, const float* w, int seqLen, int dim, float eps) {
    int threads = 256;
    rmsnorm_kernel<<<seqLen, threads>>>(x, w, dim, eps);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

int cuda_cast_f32_to_f16(const float* src, unsigned short* dst, int n) {
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    cast_f32_to_f16_kernel<<<blocks, threads>>>(src, reinterpret_cast<__half*>(dst), n);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
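// Split-KV ("flash decoding") plan shared by the wrappers below: the KV range
// is cut into chunks of kPagedAttentionSplitSize positions, one warp produces a
// partial (max m_s, normalizer l_s, unnormalized output o_s) per
// (query, head, split), and the reduce kernel merges the splits exactly via
//   g = max_s m_s,   out = sum_s exp(m_s - g) * o_s  /  sum_s l_s * exp(m_s - g).
// The scratch buffer each wrapper allocates is laid out as
//   [splitCount] partialMax | [splitCount] partialSum | [splitCount * headDim] partialOut
// with splitCount = queryCount * numHeads * numSplits, i.e.
// splitCount * (headDim + 2) floats in total.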
int cuda_paged_attention_f32_f16kv(
    const float* Q, const unsigned short* const* KBlocks, const unsigned short* const* VBlocks,
    float* out, int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int startPos
) {
    // Split-KV Flash Decoding for long contexts.
    const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
    const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
    const int qhCount = seqLen * numHeads;
    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
    if (!useSplit) {
        dim3 blocks(seqLen, numHeads);
        int threads = 32;
        paged_attention_kernel_f16kv<<<blocks, threads>>>(
            Q, KBlocks, VBlocks, out,
            seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize, scale, startPos
        );
        CHECK_CUDA(cudaGetLastError());
        return 0;
    }
    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
    if (buf == nullptr) { return 1; }
    float* partialMax = buf;
    float* partialSum = partialMax + splitCount;
    float* partialOut = partialSum + splitCount;
    dim3 blocks1(seqLen, numHeads, numSplits);
    paged_attention_split_kv_kernel<__half><<<blocks1, 32>>>(
        Q,
        reinterpret_cast<const __half* const*>(KBlocks),
        reinterpret_cast<const __half* const*>(VBlocks),
        partialMax, partialSum, partialOut,
        seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize, scale, startPos,
        numSplits, kPagedAttentionSplitSize
    );
    CHECK_CUDA(cudaGetLastError());
    dim3 blocks2(seqLen, numHeads);
    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
        partialMax, partialSum, partialOut, out, seqLen, numHeads, headDim, numSplits
    );
    CHECK_CUDA(cudaGetLastError());
    cuda_free(buf);
    return 0;
}

int cuda_paged_attention_rope_f32_f16kv(
    const float* Q, const unsigned short* const* KBlocks, const unsigned short* const* VBlocks,
    float* out, int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int startPos, float theta
) {
    if (ensure_rope_step_table(headDim, theta) != 0) { return 1; }
    // Split-KV Flash Decoding for long contexts.
    const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ?
kvLen : (startPos + seqLen); const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize; const int qhCount = seqLen * numHeads; const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold); if (!useSplit) { dim3 blocks(seqLen, numHeads); int threads = 32; paged_attention_kernel_f16kv_rope<<>>( Q, KBlocks, VBlocks, out, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize, scale, startPos ); CHECK_CUDA(cudaGetLastError()); return 0; } const size_t splitCount = (size_t)qhCount * (size_t)numSplits; const size_t totalFloats = splitCount * (size_t)(headDim + 2); float* buf = reinterpret_cast(cuda_malloc(totalFloats * sizeof(float))); if (buf == nullptr) { return 1; } float* partialMax = buf; float* partialSum = partialMax + splitCount; float* partialOut = partialSum + splitCount; dim3 blocks1(seqLen, numHeads, numSplits); paged_attention_split_kv_kernel<__half, true><<>>( Q, reinterpret_cast(KBlocks), reinterpret_cast(VBlocks), partialMax, partialSum, partialOut, seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize, scale, startPos, numSplits, kPagedAttentionSplitSize ); CHECK_CUDA(cudaGetLastError()); dim3 blocks2(seqLen, numHeads); paged_attention_split_kv_reduce_kernel<<>>( partialMax, partialSum, partialOut, out, seqLen, numHeads, headDim, numSplits ); CHECK_CUDA(cudaGetLastError()); cuda_free(buf); return 0; } int cuda_paged_attention_batch_f32_f16kv( const float* Q, const unsigned short* const* KBlocksFlat, const unsigned short* const* VBlocksFlat, const int* blockOffsets, const int* kvLens, const int* queryPos, float* out, int numTokens, int numHeads, int numKVHeads, int headDim, int blockSize, float scale, int maxKvLen ) { const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize; const int qhCount = numTokens * numHeads; const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold); if (!useSplit) { dim3 blocks(numTokens, numHeads); int threads = 32; paged_attention_batch_kernel_f16kv<<>>( Q, KBlocksFlat, VBlocksFlat, blockOffsets, kvLens, queryPos, out, numTokens, numHeads, numKVHeads, headDim, blockSize, scale ); CHECK_CUDA(cudaGetLastError()); return 0; } const size_t splitCount = (size_t)qhCount * (size_t)numSplits; const size_t totalFloats = splitCount * (size_t)(headDim + 2); float* buf = reinterpret_cast(cuda_malloc(totalFloats * sizeof(float))); if (buf == nullptr) { return 1; } float* partialMax = buf; float* partialSum = partialMax + splitCount; float* partialOut = partialSum + splitCount; dim3 blocks1(numTokens, numHeads, numSplits); paged_attention_split_kv_batch_kernel<__half><<>>( Q, reinterpret_cast(KBlocksFlat), reinterpret_cast(VBlocksFlat), blockOffsets, kvLens, queryPos, partialMax, partialSum, partialOut, numTokens, numHeads, numKVHeads, headDim, blockSize, scale, numSplits, kPagedAttentionSplitSize ); CHECK_CUDA(cudaGetLastError()); dim3 blocks2(numTokens, numHeads); paged_attention_split_kv_reduce_kernel<<>>( partialMax, partialSum, partialOut, out, numTokens, numHeads, headDim, numSplits ); CHECK_CUDA(cudaGetLastError()); cuda_free(buf); return 0; } int cuda_paged_attention_rope_batch_f32_f16kv( const float* Q, const unsigned short* const* KBlocksFlat, const unsigned short* const* VBlocksFlat, const int* blockOffsets, const int* kvLens, const int* queryPos, float* out, int numTokens, int numHeads, int numKVHeads, int headDim, int blockSize, float scale, int maxKvLen, float theta ) { if 
(ensure_rope_step_table(headDim, theta) != 0) { return 1; } const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize; const int qhCount = numTokens * numHeads; const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold); if (!useSplit) { dim3 blocks(numTokens, numHeads); int threads = 32; paged_attention_batch_kernel_f16kv_rope<<>>( Q, KBlocksFlat, VBlocksFlat, blockOffsets, kvLens, queryPos, out, numTokens, numHeads, numKVHeads, headDim, blockSize, scale ); CHECK_CUDA(cudaGetLastError()); return 0; } const size_t splitCount = (size_t)qhCount * (size_t)numSplits; const size_t totalFloats = splitCount * (size_t)(headDim + 2); float* buf = reinterpret_cast(cuda_malloc(totalFloats * sizeof(float))); if (buf == nullptr) { return 1; } float* partialMax = buf; float* partialSum = partialMax + splitCount; float* partialOut = partialSum + splitCount; dim3 blocks1(numTokens, numHeads, numSplits); paged_attention_split_kv_batch_kernel<__half, true><<>>( Q, reinterpret_cast(KBlocksFlat), reinterpret_cast(VBlocksFlat), blockOffsets, kvLens, queryPos, partialMax, partialSum, partialOut, numTokens, numHeads, numKVHeads, headDim, blockSize, scale, numSplits, kPagedAttentionSplitSize ); CHECK_CUDA(cudaGetLastError()); dim3 blocks2(numTokens, numHeads); paged_attention_split_kv_reduce_kernel<<>>( partialMax, partialSum, partialOut, out, numTokens, numHeads, headDim, numSplits ); CHECK_CUDA(cudaGetLastError()); cuda_free(buf); return 0; } __global__ void paged_attention_batch_kernel( const float* Q, const float* const* KBlocksFlat, const float* const* VBlocksFlat, const int* blockOffsets, const int* kvLens, const int* queryPos, float* out, int numTokens, int numHeads, int numKVHeads, int headDim, int blockSize, float scale ); __global__ void paged_attention_batch_kernel_f16kv( const float* Q, const unsigned short* const* KBlocksFlat, const unsigned short* const* VBlocksFlat, const int* blockOffsets, const int* kvLens, const int* queryPos, float* out, int numTokens, int numHeads, int numKVHeads, int headDim, int blockSize, float scale ); int cuda_paged_attention_batch_f32( const float* Q, const float* const* KBlocksFlat, const float* const* VBlocksFlat, const int* blockOffsets, const int* kvLens, const int* queryPos, float* out, int numTokens, int numHeads, int numKVHeads, int headDim, int blockSize, float scale, int maxKvLen ) { const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize; const int qhCount = numTokens * numHeads; const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold); if (!useSplit) { dim3 blocks(numTokens, numHeads); int threads = 256; paged_attention_batch_kernel<<>>( Q, KBlocksFlat, VBlocksFlat, blockOffsets, kvLens, queryPos, out, numTokens, numHeads, numKVHeads, headDim, blockSize, scale ); CHECK_CUDA(cudaGetLastError()); return 0; } const size_t splitCount = (size_t)qhCount * (size_t)numSplits; const size_t totalFloats = splitCount * (size_t)(headDim + 2); float* buf = reinterpret_cast(cuda_malloc(totalFloats * sizeof(float))); if (buf == nullptr) { return 1; } float* partialMax = buf; float* partialSum = partialMax + splitCount; float* partialOut = partialSum + splitCount; dim3 blocks1(numTokens, numHeads, numSplits); paged_attention_split_kv_batch_kernel<<>>( Q, KBlocksFlat, VBlocksFlat, blockOffsets, kvLens, queryPos, partialMax, partialSum, partialOut, numTokens, numHeads, numKVHeads, headDim, blockSize, scale, numSplits, 
kPagedAttentionSplitSize );
    CHECK_CUDA(cudaGetLastError());
    dim3 blocks2(numTokens, numHeads);
    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
        partialMax, partialSum, partialOut, out, numTokens, numHeads, headDim, numSplits
    );
    CHECK_CUDA(cudaGetLastError());
    cuda_free(buf);
    return 0;
}

// RoPE kernel
__global__ void rope_kernel(float* x, const int* positions, int totalDim, int headDim, float theta) {
    int seq = blockIdx.x;
    int pos = positions[seq];
    float* rowData = x + seq * totalDim;
    int halfDim = headDim / 2;
    // Each thread handles one (j, j+halfDim) pair across all heads.
    // Compute sin/cos once per j and reuse across heads.
    for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
        const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
        const float freq = pos * invFreq;
        float sinF, cosF;
        sincosf(freq, &sinF, &cosF);
        for (int headStart = 0; headStart < totalDim; headStart += headDim) {
            const int idx0 = headStart + j;
            const int idx1 = idx0 + halfDim;
            const float v0 = rowData[idx0];
            const float v1 = rowData[idx1];
            rowData[idx0] = v0 * cosF - v1 * sinF;
            rowData[idx1] = v1 * cosF + v0 * sinF;
        }
    }
}

int cuda_rope_f32(float* x, const int* positions, int seqLen, int numHeads, int headDim, float theta) {
    int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
    rope_kernel<<<seqLen, threads>>>(x, positions, numHeads * headDim, headDim, theta);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
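// Reference only (not called anywhere): a host-side sketch of the exact rotation
// rope_kernel applies to one row, handy when validating the GPU path against a
// CPU result. The pairing convention is (j, j + headDim/2) within each head.
// The function name is illustrative, not part of the exported API.
static inline void rope_rotate_row_reference(float* row, int pos, int numHeads, int headDim, float theta) {
    const int halfDim = headDim / 2;
    for (int h = 0; h < numHeads; h++) {
        float* head = row + h * headDim;
        for (int j = 0; j < halfDim; j++) {
            const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
            const float freq = (float)pos * invFreq;
            const float c = cosf(freq);
            const float s = sinf(freq);
            const float v0 = head[j];
            const float v1 = head[j + halfDim];
            head[j]           = v0 * c - v1 * s;
            head[j + halfDim] = v1 * c + v0 * s;
        }
    }
}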
// Optimized RoPE kernel for single position (scalar pos)
__global__ void rope_kernel_single(float* x, int pos, int totalDim, int headDim, float theta) {
    // seq is always 0 for single token
    float* rowData = x; // x + 0 * totalDim
    int halfDim = headDim / 2;
    // Each thread handles one (j, j+halfDim) pair across all heads.
    // Compute sin/cos once per j and reuse across heads.
    for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
        const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
        const float freq = pos * invFreq;
        float sinF, cosF;
        sincosf(freq, &sinF, &cosF);
        for (int headStart = 0; headStart < totalDim; headStart += headDim) {
            const int idx0 = headStart + j;
            const int idx1 = idx0 + halfDim;
            const float v0 = rowData[idx0];
            const float v1 = rowData[idx1];
            rowData[idx0] = v0 * cosF - v1 * sinF;
            rowData[idx1] = v1 * cosF + v0 * sinF;
        }
    }
}

int cuda_rope_f32_single(float* x, int pos, int numHeads, int headDim, float theta) {
    int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
    rope_kernel_single<<<1, threads>>>(x, pos, numHeads * headDim, headDim, theta);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// Softmax kernel: one block per row
__global__ void softmax_kernel(float* x, int cols) {
    int row = blockIdx.x;
    float* rowData = x + row * cols;
    __shared__ float smax[256];
    __shared__ float ssum[256];
    // Find max
    float threadMax = -INFINITY;
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        threadMax = fmaxf(threadMax, rowData[i]);
    }
    smax[threadIdx.x] = threadMax;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
        }
        __syncthreads();
    }
    float maxVal = smax[0];
    // Compute exp and sum
    float threadSum = 0.0f;
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        float val = expf(rowData[i] - maxVal);
        rowData[i] = val;
        threadSum += val;
    }
    ssum[threadIdx.x] = threadSum;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) {
            ssum[threadIdx.x] += ssum[threadIdx.x + s];
        }
        __syncthreads();
    }
    float sum = ssum[0];
    // Normalize
    for (int i = threadIdx.x; i < cols; i += blockDim.x) {
        rowData[i] /= sum;
    }
}

int cuda_softmax_f32(float* x, int rows, int cols) {
    int threads = 256;
    softmax_kernel<<<rows, threads>>>(x, cols);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// ============================================================
// Top-K Logits Selection (for sampling without full D2H)
// ============================================================
#define TOPK_THREADS 256
#define TOPK_LOCAL 8
#define TOPK_SEGMENT (TOPK_THREADS * TOPK_LOCAL) // 2048
#define TOPK_MAX_K 64

static __device__ __forceinline__ float apply_rep_penalty(float score, bool hit, float penalty) {
    if (!hit) return score;
    if (score > 0.0f) return score / penalty;
    return score * penalty;
}

__global__ void topk_logits_kernel(
    const float* logits, int vocab,
    const int* rep_ids, int rep_count, float rep_penalty,
    int k, int* out_ids, float* out_scores
) {
    const int blockStart = blockIdx.x * TOPK_SEGMENT;
    const int tid = threadIdx.x;
    if (tid >= TOPK_THREADS) return;
    float localScores[TOPK_LOCAL];
    int localIds[TOPK_LOCAL];
    #pragma unroll
    for (int i = 0; i < TOPK_LOCAL; i++) {
        localScores[i] = -INFINITY;
        localIds[i] = -1;
    }
    // Each thread processes up to TOPK_LOCAL elements of this block's segment,
    // strided by TOPK_THREADS so neighboring threads read neighboring logits.
    #pragma unroll
    for (int j = 0; j < TOPK_LOCAL; j++) {
        int idx = blockStart + tid + j * TOPK_THREADS;
        if (idx >= vocab) break;
        float score = logits[idx];
        bool hit = false;
        // rep_count is small (<=64). Linear scan is fine.
        for (int r = 0; r < rep_count; r++) {
            if (rep_ids[r] == idx) { hit = true; break; }
        }
        score = apply_rep_penalty(score, hit, rep_penalty);
        // Insert into local top list (descending)
        int pos = TOPK_LOCAL;
        for (int t = 0; t < TOPK_LOCAL; t++) {
            if (score > localScores[t]) { pos = t; break; }
        }
        if (pos < TOPK_LOCAL) {
            for (int t = TOPK_LOCAL - 1; t > pos; t--) {
                localScores[t] = localScores[t-1];
                localIds[t] = localIds[t-1];
            }
            localScores[pos] = score;
            localIds[pos] = idx;
        }
    }
    __shared__ float shScores[TOPK_SEGMENT];
    __shared__ int shIds[TOPK_SEGMENT];
    // Write local results to shared candidate pool.
    #pragma unroll
    for (int j = 0; j < TOPK_LOCAL; j++) {
        int out = tid * TOPK_LOCAL + j;
        shScores[out] = localScores[j];
        shIds[out] = localIds[j];
    }
    __syncthreads();
    if (tid == 0) {
        // Block-level exact top-k from TOPK_SEGMENT candidates.
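        // A single thread scans the 2048-entry shared candidate pool and keeps
        // the k best via insertion into a small descending array. Candidates
        // that cannot beat the current k-th best (bestScores[k-1]) are skipped
        // early, so the common case is one compare per candidate.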
if (k > TOPK_MAX_K) k = TOPK_MAX_K; float bestScores[TOPK_MAX_K]; int bestIds[TOPK_MAX_K]; for (int i = 0; i < k; i++) { bestScores[i] = -INFINITY; bestIds[i] = -1; } for (int i = 0; i < TOPK_SEGMENT; i++) { float score = shScores[i]; int id = shIds[i]; if (id < 0) continue; // Insert into best (descending) if (score <= bestScores[k-1]) continue; int pos = k; for (int t = 0; t < k; t++) { if (score > bestScores[t]) { pos = t; break; } } if (pos < k) { for (int t = k - 1; t > pos; t--) { bestScores[t] = bestScores[t-1]; bestIds[t] = bestIds[t-1]; } bestScores[pos] = score; bestIds[pos] = id; } } int base = blockIdx.x * k; for (int i = 0; i < k; i++) { out_ids[base + i] = bestIds[i]; out_scores[base + i] = bestScores[i]; } } } int cuda_topk_logits_f32( const float* logits, int vocab, const int* rep_ids, int rep_count, float rep_penalty, int k, int* out_ids, float* out_scores ) { if (k <= 0) return 0; if (k > TOPK_MAX_K) k = TOPK_MAX_K; int blocks = (vocab + TOPK_SEGMENT - 1) / TOPK_SEGMENT; dim3 grid(blocks); dim3 block(TOPK_THREADS); topk_logits_kernel<<>>(logits, vocab, rep_ids, rep_count, rep_penalty, k, out_ids, out_scores); CHECK_CUDA(cudaGetLastError()); return 0; } // ============================================================ // Attention Kernel // Computes: softmax(Q @ K.T / scale + causal_mask) @ V // ============================================================ __global__ void attention_kernel( const float* Q, const float* K, const float* V, float* out, int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim, float scale, int startPos ) { // Each block handles one (seq, head) pair int seq = blockIdx.x; int head = blockIdx.y; int kvHead = head / (numHeads / numKVHeads); // GQA support const float* q = Q + seq * numHeads * headDim + head * headDim; float* o = out + seq * numHeads * headDim + head * headDim; // Shared memory for attention scores extern __shared__ float shared[]; float* scores = shared; // [kvLen] // Compute Q @ K.T for this head for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) { const float* k = K + kv * numKVHeads * headDim + kvHead * headDim; float dot = 0.0f; for (int d = 0; d < headDim; d++) { dot += q[d] * k[d]; } // Apply causal mask int queryPos = startPos + seq; int keyPos = kv; if (keyPos > queryPos) { dot = -INFINITY; } scores[kv] = dot * scale; } __syncthreads(); // Softmax over scores // Find max float maxVal = -INFINITY; for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) { maxVal = fmaxf(maxVal, scores[kv]); } // Reduce max across threads __shared__ float smax[256]; smax[threadIdx.x] = maxVal; __syncthreads(); for (int s = blockDim.x / 2; s > 0; s >>= 1) { if (threadIdx.x < s) { smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]); } __syncthreads(); } maxVal = smax[0]; // Exp and sum float sum = 0.0f; for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) { float val = expf(scores[kv] - maxVal); scores[kv] = val; sum += val; } __shared__ float ssum[256]; ssum[threadIdx.x] = sum; __syncthreads(); for (int s = blockDim.x / 2; s > 0; s >>= 1) { if (threadIdx.x < s) { ssum[threadIdx.x] += ssum[threadIdx.x + s]; } __syncthreads(); } sum = ssum[0]; // Normalize for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) { scores[kv] /= sum; } __syncthreads(); // Compute weighted sum of V for (int d = threadIdx.x; d < headDim; d += blockDim.x) { float val = 0.0f; for (int kv = 0; kv < kvLen; kv++) { const float* v = V + kv * numKVHeads * headDim + kvHead * headDim; val += scores[kv] * v[d]; } o[d] = val; } } __global__ void 
paged_attention_kernel( const float* Q, const float* const* KBlocks, const float* const* VBlocks, float* out, int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim, int blockSize, float scale, int startPos ) { int seq = blockIdx.x; int head = blockIdx.y; int kvHead = head / (numHeads / numKVHeads); const float* q = Q + seq * numHeads * headDim + head * headDim; float* o = out + seq * numHeads * headDim + head * headDim; const int kvStride = numKVHeads * headDim; int queryPos = startPos + seq; // Pass 1: max float localMax = -INFINITY; for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) { int keyPos = kv; if (keyPos > queryPos) { continue; } const int blockIdxKV = kv / blockSize; const int blockOff = kv % blockSize; const float* kBlock = KBlocks[blockIdxKV]; const float* k = kBlock + blockOff * kvStride + kvHead * headDim; float dot = 0.0f; for (int d = 0; d < headDim; d++) { dot += q[d] * k[d]; } float score = dot * scale; localMax = fmaxf(localMax, score); } __shared__ float smax[256]; smax[threadIdx.x] = localMax; __syncthreads(); for (int s = blockDim.x / 2; s > 0; s >>= 1) { if (threadIdx.x < s) { smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]); } __syncthreads(); } float maxVal = smax[0]; // Pass 2: sum exp float localSum = 0.0f; for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) { int keyPos = kv; if (keyPos > queryPos) { continue; } const int blockIdxKV = kv / blockSize; const int blockOff = kv % blockSize; const float* kBlock = KBlocks[blockIdxKV]; const float* k = kBlock + blockOff * kvStride + kvHead * headDim; float dot = 0.0f; for (int d = 0; d < headDim; d++) { dot += q[d] * k[d]; } float score = dot * scale; localSum += expf(score - maxVal); } __shared__ float ssum[256]; ssum[threadIdx.x] = localSum; __syncthreads(); for (int s = blockDim.x / 2; s > 0; s >>= 1) { if (threadIdx.x < s) { ssum[threadIdx.x] += ssum[threadIdx.x + s]; } __syncthreads(); } float sumVal = ssum[0]; float invSum = (sumVal > 0.0f) ? (1.0f / sumVal) : 0.0f; // Pass 3: output accumulation for (int d = threadIdx.x; d < headDim; d += blockDim.x) { float outVal = 0.0f; for (int kv = 0; kv < kvLen; kv++) { int keyPos = kv; if (keyPos > queryPos) { break; } const int blockIdxKV = kv / blockSize; const int blockOff = kv % blockSize; const float* kBlock = KBlocks[blockIdxKV]; const float* k = kBlock + blockOff * kvStride + kvHead * headDim; float dot = 0.0f; for (int kd = 0; kd < headDim; kd++) { dot += q[kd] * k[kd]; } float score = dot * scale; float w = expf(score - maxVal) * invSum; const float* vBlock = VBlocks[blockIdxKV]; const float* v = vBlock + blockOff * kvStride + kvHead * headDim; outVal += w * v[d]; } o[d] = outVal; } } __global__ void paged_attention_batch_kernel( const float* Q, const float* const* KBlocksFlat, const float* const* VBlocksFlat, const int* blockOffsets, const int* kvLens, const int* queryPos, float* out, int numTokens, int numHeads, int numKVHeads, int headDim, int blockSize, float scale ) { int tok = blockIdx.x; int head = blockIdx.y; if (tok >= numTokens) { return; } int kvHead = head / (numHeads / numKVHeads); const float* q = Q + tok * numHeads * headDim + head * headDim; float* o = out + tok * numHeads * headDim + head * headDim; int kvLen = kvLens[tok]; int qPos = queryPos[tok]; int base = blockOffsets[tok]; const int kvStride = numKVHeads * headDim; const int effectiveLen = (kvLen < (qPos + 1)) ? 
kvLen : (qPos + 1);
    float acc = 0.0f;
    if (threadIdx.x >= headDim) { acc = 0.0f; }
    __shared__ float m;
    __shared__ float l;
    __shared__ float alpha;
    __shared__ float beta;
    __shared__ float dotShared;
    if (threadIdx.x == 0) {
        m = -INFINITY;
        l = 0.0f;
    }
    __syncthreads();
    for (int kv = 0; kv < effectiveLen; kv++) {
        const int bidx = kv / blockSize;
        const int boff = kv % blockSize;
        const float* kBlock = KBlocksFlat[base + bidx];
        const float* k = kBlock + boff * kvStride + kvHead * headDim;
        float partial = 0.0f;
        for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
            partial = fmaf(q[d], k[d], partial);
        }
        // block reduction (sum)
        for (int offset = 16; offset > 0; offset >>= 1) {
            partial += __shfl_down_sync(0xffffffff, partial, offset);
        }
        __shared__ float warpSum[8];
        int lane = threadIdx.x & 31;
        int warp = threadIdx.x >> 5;
        if (lane == 0) { warpSum[warp] = partial; }
        __syncthreads();
        if (warp == 0) {
            float v = (lane < 8) ? warpSum[lane] : 0.0f;
            for (int offset = 16; offset > 0; offset >>= 1) {
                v += __shfl_down_sync(0xffffffff, v, offset);
            }
            if (lane == 0) { dotShared = v; }
        }
        __syncthreads();
        float score = dotShared * scale;
        if (threadIdx.x == 0) {
            float newM = fmaxf(m, score);
            float a = expf(m - newM);
            float b = expf(score - newM);
            m = newM;
            l = l * a + b;
            alpha = a;
            beta = b;
        }
        __syncthreads();
        if (threadIdx.x < headDim) {
            const float* vBlock = VBlocksFlat[base + bidx];
            const float* v = vBlock + boff * kvStride + kvHead * headDim;
            acc = fmaf(beta, v[threadIdx.x], acc * alpha);
        }
        __syncthreads();
    }
    if (threadIdx.x < headDim) {
        float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
        o[threadIdx.x] = acc * invL;
    }
}

int cuda_attention_f32(
    const float* Q, const float* K, const float* V, float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    float scale, int startPos
) {
    dim3 blocks(seqLen, numHeads);
    int threads = 256;
    size_t sharedMem = kvLen * sizeof(float);
    attention_kernel<<<blocks, threads, sharedMem>>>(
        Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
    );
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
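// Illustrative sketch (not called by the wrappers): attention_kernel buffers one
// float per KV position in dynamic shared memory, so very long contexts can
// exceed the per-block shared memory limit. A caller could guard against that
// along these lines before picking this path; the function name is hypothetical.
static inline bool attention_scores_fit_in_shared(int kvLen) {
    int dev = 0;
    if (cudaGetDevice(&dev) != cudaSuccess) { return false; }
    int maxSharedPerBlock = 0; // bytes
    if (cudaDeviceGetAttribute(&maxSharedPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, dev) != cudaSuccess) {
        return false;
    }
    // attention_kernel also declares two static 256-float reduction buffers.
    const size_t needed = (size_t)kvLen * sizeof(float) + 2u * 256u * sizeof(float);
    return needed <= (size_t)maxSharedPerBlock;
}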
int cuda_paged_attention_f32(
    const float* Q, const float* const* KBlocks, const float* const* VBlocks, float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    int blockSize, float scale, int startPos
) {
    // Split-KV Flash Decoding for long contexts.
    const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
    const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
    const int qhCount = seqLen * numHeads;
    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
    if (!useSplit) {
        dim3 blocks(seqLen, numHeads);
        int threads = 256;
        paged_attention_kernel<<<blocks, threads>>>(
            Q, KBlocks, VBlocks, out,
            seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize, scale, startPos
        );
        CHECK_CUDA(cudaGetLastError());
        return 0;
    }
    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
    if (buf == nullptr) { return 1; }
    float* partialMax = buf;
    float* partialSum = partialMax + splitCount;
    float* partialOut = partialSum + splitCount;
    dim3 blocks1(seqLen, numHeads, numSplits);
    paged_attention_split_kv_kernel<<<blocks1, 32>>>(
        Q, KBlocks, VBlocks, partialMax, partialSum, partialOut,
        seqLen, kvLen, numHeads, numKVHeads, headDim, blockSize, scale, startPos,
        numSplits, kPagedAttentionSplitSize
    );
    CHECK_CUDA(cudaGetLastError());
    dim3 blocks2(seqLen, numHeads);
    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
        partialMax, partialSum, partialOut, out, seqLen, numHeads, headDim, numSplits
    );
    CHECK_CUDA(cudaGetLastError());
    cuda_free(buf);
    return 0;
}

int cuda_attention_f32_timed(
    const float* Q, const float* K, const float* V, float* out,
    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
    float scale, int startPos, float* ms
) {
    cudaEvent_t evStart;
    cudaEvent_t evStop;
    CHECK_CUDA(cudaEventCreate(&evStart));
    CHECK_CUDA(cudaEventCreate(&evStop));
    dim3 blocks(seqLen, numHeads);
    int threads = 256;
    size_t sharedMem = kvLen * sizeof(float);
    CHECK_CUDA(cudaEventRecord(evStart));
    attention_kernel<<<blocks, threads, sharedMem>>>(
        Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
    );
    CHECK_CUDA(cudaEventRecord(evStop));
    CHECK_CUDA(cudaEventSynchronize(evStop));
    float elapsed = 0.0f;
    CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
    if (ms != NULL) { *ms = elapsed; }
    CHECK_CUDA(cudaEventDestroy(evStart));
    CHECK_CUDA(cudaEventDestroy(evStop));
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// Debug helper
int cuda_print_struct_sizes() {
    printf("GPU Struct Sizes:\n");
    printf("BlockQ2_K: %zu\n", sizeof(BlockQ2_K));
    printf("BlockQ3_K: %zu\n", sizeof(BlockQ3_K));
    printf("BlockQ4_K: %zu\n", sizeof(BlockQ4_K));
    printf("BlockQ6_K: %zu\n", sizeof(BlockQ6_K));
    printf("BlockQ8_K: %zu\n", sizeof(BlockQ8_K));
    return 0;
}
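// Reference only: a host-side sketch of the merge that
// paged_attention_split_kv_reduce_kernel performs for one (query, head) pair.
// partialMax/partialSum hold numSplits entries, partialOut holds
// numSplits * headDim unnormalized accumulators; out receives headDim floats.
// This mirrors the standard log-sum-exp combination of per-split softmax
// partials; the function name is illustrative, not part of the exported API.
static inline void split_kv_merge_reference(
    const float* partialMax, const float* partialSum, const float* partialOut,
    float* out, int numSplits, int headDim
) {
    float gmax = -INFINITY;
    for (int s = 0; s < numSplits; s++) {
        gmax = fmaxf(gmax, partialMax[s]);
    }
    if (gmax == -INFINITY) { // every split was empty
        for (int d = 0; d < headDim; d++) { out[d] = 0.0f; }
        return;
    }
    float sum = 0.0f;
    for (int s = 0; s < numSplits; s++) {
        sum += partialSum[s] * expf(partialMax[s] - gmax);
    }
    const float invSum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
    for (int d = 0; d < headDim; d++) {
        float acc = 0.0f;
        for (int s = 0; s < numSplits; s++) {
            acc += expf(partialMax[s] - gmax) * partialOut[s * headDim + d];
        }
        out[d] = acc * invSum;
    }
}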