|
|
@@ -0,0 +1,2165 @@
|
|
|
+#include "cuda_common.cuh"
|
|
|
+#include <cuda_fp16.h>
|
|
|
+#include <stdint.h>
|
|
|
+#include <string.h>
|
|
|
+
|
|
|
+namespace {
|
|
|
+constexpr int kPagedAttentionSplitSize = 1024;
|
|
|
+constexpr int kPagedAttentionSplitQHThreshold = 4096; // queryCount*numHeads threshold
|
|
|
+} // namespace
|
|
|
+
|
|
|
+// ============================================================
|
|
|
+// Fused RoPE helpers (constant step table + fast complex pow)
|
|
|
+// ============================================================
|
|
|
+
|
|
|
+// Stores cos/sin of per-dimension "step" angle (invFreq) for RoPE.
|
|
|
+// Only indices [0, headDim/2) are used. Max supported headDim is 256.
|
|
|
+__device__ __constant__ float rope_cos_step_const[128];
|
|
|
+__device__ __constant__ float rope_sin_step_const[128];
|
|
|
+
|
|
|
+static int g_rope_step_inited[32] = {0};
|
|
|
+static int g_rope_step_head_dim[32] = {0};
|
|
|
+static uint32_t g_rope_step_theta_bits[32] = {0};
|
|
|
+
|
|
|
+static int ensure_rope_step_table(int headDim, float theta) {
|
|
|
+ if (headDim <= 0 || headDim > 256 || (headDim & 1) != 0) {
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+
|
|
|
+ int dev = 0;
|
|
|
+ CHECK_CUDA(cudaGetDevice(&dev));
|
|
|
+ if (dev < 0 || dev >= 32) {
|
|
|
+ dev = 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ uint32_t thetaBits = 0;
|
|
|
+ memcpy(&thetaBits, &theta, sizeof(thetaBits));
|
|
|
+ if (g_rope_step_inited[dev] && g_rope_step_head_dim[dev] == headDim && g_rope_step_theta_bits[dev] == thetaBits) {
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ float cosStep[128];
|
|
|
+ float sinStep[128];
|
|
|
+ const int halfDim = headDim / 2;
|
|
|
+ for (int j = 0; j < 128; j++) {
|
|
|
+ if (j < halfDim) {
|
|
|
+ // invFreq = theta^(-2j/headDim)
|
|
|
+ const double exp = -2.0 * (double)j / (double)headDim;
|
|
|
+ const double invFreq = pow((double)theta, exp);
|
|
|
+ cosStep[j] = (float)cos(invFreq);
|
|
|
+ sinStep[j] = (float)sin(invFreq);
|
|
|
+ } else {
|
|
|
+ cosStep[j] = 1.0f;
|
|
|
+ sinStep[j] = 0.0f;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ CHECK_CUDA(cudaMemcpyToSymbol(rope_cos_step_const, cosStep, sizeof(cosStep), 0, cudaMemcpyHostToDevice));
|
|
|
+ CHECK_CUDA(cudaMemcpyToSymbol(rope_sin_step_const, sinStep, sizeof(sinStep), 0, cudaMemcpyHostToDevice));
|
|
|
+
|
|
|
+ g_rope_step_inited[dev] = 1;
|
|
|
+ g_rope_step_head_dim[dev] = headDim;
|
|
|
+ g_rope_step_theta_bits[dev] = thetaBits;
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+__device__ __forceinline__ float2 complex_mul_f2(float2 a, float2 b) {
|
|
|
+ // (a.x + i a.y) * (b.x + i b.y)
|
|
|
+ return make_float2(
|
|
|
+ fmaf(a.x, b.x, -a.y * b.y),
|
|
|
+ fmaf(a.x, b.y, a.y * b.x)
|
|
|
+ );
|
|
|
+}
|
|
|
+
|
|
|
+__device__ __forceinline__ float2 complex_pow_int(float2 base, int exp) {
|
|
|
+ float2 result = make_float2(1.0f, 0.0f);
|
|
|
+ float2 b = base;
|
|
|
+ int e = exp;
|
|
|
+ while (e > 0) {
|
|
|
+ if (e & 1) {
|
|
|
+ result = complex_mul_f2(result, b);
|
|
|
+ }
|
|
|
+ b = complex_mul_f2(b, b);
|
|
|
+ e >>= 1;
|
|
|
+ }
|
|
|
+ return result;
|
|
|
+}
|
|
|
+
|
|
|
+__device__ __forceinline__ void rope_advance_neg(float& cosv, float& sinv, float cosStep, float sinStep) {
|
|
|
+ // Multiply by exp(-i*step): (cos + i sin) * (cosStep - i sinStep)
|
|
|
+ const float c = cosv;
|
|
|
+ const float s = sinv;
|
|
|
+ cosv = fmaf(c, cosStep, s * sinStep);
|
|
|
+ sinv = fmaf(s, cosStep, -c * sinStep);
|
|
|
+}
|
|
|
+
|
|
|
+// ============================================================
|
|
|
+// Neural Network Operations
|
|
|
+// ============================================================
|
|
|
+
|
|
|
+// RMSNorm kernel: one block per row
|
|
|
+__global__ void rmsnorm_kernel(float* x, const float* w, int dim, float eps) {
|
|
|
+ int row = blockIdx.x;
|
|
|
+ float* rowData = x + row * dim;
|
|
|
+
|
|
|
+ float sum = 0.0f;
|
|
|
+ for (int i = threadIdx.x; i < dim; i += blockDim.x) {
|
|
|
+ float v = rowData[i];
|
|
|
+ sum = fmaf(v, v, sum);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Warp reduce
|
|
|
+ for (int offset = 16; offset > 0; offset >>= 1) {
|
|
|
+ sum += __shfl_down_sync(0xffffffff, sum, offset);
|
|
|
+ }
|
|
|
+
|
|
|
+ __shared__ float warpSum[8];
|
|
|
+ __shared__ float rms;
|
|
|
+
|
|
|
+ int lane = threadIdx.x & 31;
|
|
|
+ int warp = threadIdx.x >> 5;
|
|
|
+ if (lane == 0) {
|
|
|
+ warpSum[warp] = sum;
|
|
|
+ }
|
|
|
+ __syncthreads();
|
|
|
+
|
|
|
+ if (warp == 0) {
|
|
|
+ float v = (lane < 8) ? warpSum[lane] : 0.0f;
|
|
|
+ for (int offset = 16; offset > 0; offset >>= 1) {
|
|
|
+ v += __shfl_down_sync(0xffffffff, v, offset);
|
|
|
+ }
|
|
|
+ if (lane == 0) {
|
|
|
+ rms = rsqrtf(v / dim + eps);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ __syncthreads();
|
|
|
+
|
|
|
+ for (int i = threadIdx.x; i < dim; i += blockDim.x) {
|
|
|
+ rowData[i] = rowData[i] * rms * w[i];
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+__global__ void paged_attention_batch_kernel_f16kv(
|
|
|
+ const float* Q,
|
|
|
+ const unsigned short* const* KBlocksFlat,
|
|
|
+ const unsigned short* const* VBlocksFlat,
|
|
|
+ const int* blockOffsets,
|
|
|
+ const int* kvLens,
|
|
|
+ const int* queryPos,
|
|
|
+ float* out,
|
|
|
+ int numTokens,
|
|
|
+ int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale
|
|
|
+) {
|
|
|
+ const int tok = blockIdx.x;
|
|
|
+ const int head = blockIdx.y;
|
|
|
+ const int lane = threadIdx.x & 31;
|
|
|
+ if (tok >= numTokens) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // One warp per (tok, head)
|
|
|
+ if (threadIdx.x >= 32) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int kvHead = head / (numHeads / numKVHeads);
|
|
|
+ const float* q = Q + tok * numHeads * headDim + head * headDim;
|
|
|
+ float* o = out + tok * numHeads * headDim + head * headDim;
|
|
|
+
|
|
|
+ const int kvLen = kvLens[tok];
|
|
|
+ const int qPos = queryPos[tok];
|
|
|
+ const int base = blockOffsets[tok];
|
|
|
+
|
|
|
+ const int kvStride = numKVHeads * headDim;
|
|
|
+ const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
|
|
|
+
|
|
|
+ // Cache Q in registers (per lane) to avoid reloading it for every KV token.
|
|
|
+ float qreg[8];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ qreg[i] = (d < headDim) ? q[d] : 0.0f;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Support headDim up to 256 (<= 8 values per lane)
|
|
|
+ float acc[8];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ acc[i] = 0.0f;
|
|
|
+ }
|
|
|
+
|
|
|
+ float m = -INFINITY;
|
|
|
+ float l = 0.0f;
|
|
|
+
|
|
|
+ for (int kv = 0; kv < effectiveLen; kv++) {
|
|
|
+ const int bidx = kv / blockSize;
|
|
|
+ const int boff = kv % blockSize;
|
|
|
+
|
|
|
+ const __half* kBlock = reinterpret_cast<const __half*>(KBlocksFlat[base + bidx]);
|
|
|
+ const __half* k = kBlock + boff * kvStride + kvHead * headDim;
|
|
|
+
|
|
|
+ float dot = 0.0f;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ dot = fmaf(qreg[i], __half2float(k[d]), dot);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (int offset = 16; offset > 0; offset >>= 1) {
|
|
|
+ dot += __shfl_down_sync(0xffffffff, dot, offset);
|
|
|
+ }
|
|
|
+ dot = __shfl_sync(0xffffffff, dot, 0);
|
|
|
+
|
|
|
+ const float score = dot * scale;
|
|
|
+ float alpha = 1.0f;
|
|
|
+ float beta = 0.0f;
|
|
|
+ if (lane == 0) {
|
|
|
+ const float newM = fmaxf(m, score);
|
|
|
+ alpha = __expf(m - newM);
|
|
|
+ beta = __expf(score - newM);
|
|
|
+ m = newM;
|
|
|
+ l = l * alpha + beta;
|
|
|
+ }
|
|
|
+ alpha = __shfl_sync(0xffffffff, alpha, 0);
|
|
|
+ beta = __shfl_sync(0xffffffff, beta, 0);
|
|
|
+ l = __shfl_sync(0xffffffff, l, 0);
|
|
|
+
|
|
|
+ const __half* vBlock = reinterpret_cast<const __half*>(VBlocksFlat[base + bidx]);
|
|
|
+ const __half* v = vBlock + boff * kvStride + kvHead * headDim;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ o[d] = acc[i] * invL;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Fused RoPE + paged attention (batch, f16 KV). Expects un-rotated Q/K.
|
|
|
+__global__ void paged_attention_batch_kernel_f16kv_rope(
|
|
|
+ const float* Q,
|
|
|
+ const unsigned short* const* KBlocksFlat,
|
|
|
+ const unsigned short* const* VBlocksFlat,
|
|
|
+ const int* blockOffsets,
|
|
|
+ const int* kvLens,
|
|
|
+ const int* queryPos,
|
|
|
+ float* out,
|
|
|
+ int numTokens,
|
|
|
+ int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale
|
|
|
+) {
|
|
|
+ const int tok = blockIdx.x;
|
|
|
+ const int head = blockIdx.y;
|
|
|
+ const int lane = threadIdx.x & 31;
|
|
|
+ if (tok >= numTokens) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // One warp per (tok, head)
|
|
|
+ if (threadIdx.x >= 32) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int kvHead = head / (numHeads / numKVHeads);
|
|
|
+ const float* q = Q + tok * numHeads * headDim + head * headDim;
|
|
|
+ float* o = out + tok * numHeads * headDim + head * headDim;
|
|
|
+
|
|
|
+ const int kvLen = kvLens[tok];
|
|
|
+ const int qPos = queryPos[tok];
|
|
|
+ const int base = blockOffsets[tok];
|
|
|
+
|
|
|
+ const int kvStride = numKVHeads * headDim;
|
|
|
+ const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
|
|
|
+ const int halfDim = headDim >> 1;
|
|
|
+
|
|
|
+ // Cache Q pairs + per-dim RoPE phase for delta=(qPos - kv) with kv starting at 0.
|
|
|
+ float q0[4];
|
|
|
+ float q1[4];
|
|
|
+ float cosStep[4];
|
|
|
+ float sinStep[4];
|
|
|
+ float cosDelta[4];
|
|
|
+ float sinDelta[4];
|
|
|
+ int pairCount = 0;
|
|
|
+ for (int j = lane; j < halfDim; j += 32) {
|
|
|
+ q0[pairCount] = q[j];
|
|
|
+ q1[pairCount] = q[j + halfDim];
|
|
|
+ cosStep[pairCount] = rope_cos_step_const[j];
|
|
|
+ sinStep[pairCount] = rope_sin_step_const[j];
|
|
|
+ const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
|
|
|
+ const float2 ph = complex_pow_int(baseStep, qPos);
|
|
|
+ cosDelta[pairCount] = ph.x;
|
|
|
+ sinDelta[pairCount] = ph.y;
|
|
|
+ pairCount++;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Support headDim up to 256.
|
|
|
+ float acc[8];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ acc[i] = 0.0f;
|
|
|
+ }
|
|
|
+
|
|
|
+ float m = -INFINITY;
|
|
|
+ float l = 0.0f;
|
|
|
+
|
|
|
+ for (int kv = 0; kv < effectiveLen; kv++) {
|
|
|
+ const int bidx = kv / blockSize;
|
|
|
+ const int boff = kv % blockSize;
|
|
|
+
|
|
|
+ const __half* kBlock = reinterpret_cast<const __half*>(KBlocksFlat[base + bidx]);
|
|
|
+ const __half* k = kBlock + boff * kvStride + kvHead * headDim;
|
|
|
+
|
|
|
+ float dot = 0.0f;
|
|
|
+ #pragma unroll
|
|
|
+ for (int pi = 0; pi < 4; pi++) {
|
|
|
+ if (pi >= pairCount) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ const int j = lane + 32 * pi;
|
|
|
+ if (j >= halfDim) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ const float k0 = __half2float(k[j]);
|
|
|
+ const float k1 = __half2float(k[j + halfDim]);
|
|
|
+ const float a = fmaf(q0[pi], k0, q1[pi] * k1); // q0*k0 + q1*k1
|
|
|
+ const float b = fmaf(q0[pi], k1, -q1[pi] * k0); // q0*k1 - q1*k0
|
|
|
+ dot = fmaf(cosDelta[pi], a, dot);
|
|
|
+ dot = fmaf(sinDelta[pi], b, dot);
|
|
|
+ }
|
|
|
+ for (int offset = 16; offset > 0; offset >>= 1) {
|
|
|
+ dot += __shfl_down_sync(0xffffffff, dot, offset);
|
|
|
+ }
|
|
|
+ dot = __shfl_sync(0xffffffff, dot, 0);
|
|
|
+
|
|
|
+ // Advance delta -> delta-1 for next kv.
|
|
|
+ #pragma unroll
|
|
|
+ for (int pi = 0; pi < 4; pi++) {
|
|
|
+ if (pi >= pairCount) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
|
|
|
+ }
|
|
|
+
|
|
|
+ const float score = dot * scale;
|
|
|
+ float alpha = 1.0f;
|
|
|
+ float beta = 0.0f;
|
|
|
+ if (lane == 0) {
|
|
|
+ const float newM = fmaxf(m, score);
|
|
|
+ alpha = __expf(m - newM);
|
|
|
+ beta = __expf(score - newM);
|
|
|
+ m = newM;
|
|
|
+ l = l * alpha + beta;
|
|
|
+ }
|
|
|
+ alpha = __shfl_sync(0xffffffff, alpha, 0);
|
|
|
+ beta = __shfl_sync(0xffffffff, beta, 0);
|
|
|
+ l = __shfl_sync(0xffffffff, l, 0);
|
|
|
+
|
|
|
+ const __half* vBlock = reinterpret_cast<const __half*>(VBlocksFlat[base + bidx]);
|
|
|
+ const __half* v = vBlock + boff * kvStride + kvHead * headDim;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ o[d] = acc[i] * invL;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+__global__ void cast_f32_to_f16_kernel(const float* src, __half* dst, int n) {
|
|
|
+ int i = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
+ if (i < n) {
|
|
|
+ dst[i] = __float2half_rn(src[i]);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+__global__ void paged_attention_kernel_f16kv(
|
|
|
+ const float* Q,
|
|
|
+ const unsigned short* const* KBlocks,
|
|
|
+ const unsigned short* const* VBlocks,
|
|
|
+ float* out,
|
|
|
+ int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale, int startPos
|
|
|
+) {
|
|
|
+ const int seq = blockIdx.x;
|
|
|
+ const int head = blockIdx.y;
|
|
|
+ const int lane = threadIdx.x & 31;
|
|
|
+ if (seq >= seqLen) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // One warp per (seq, head)
|
|
|
+ if (threadIdx.x >= 32) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int kvHead = head / (numHeads / numKVHeads);
|
|
|
+ const float* q = Q + seq * numHeads * headDim + head * headDim;
|
|
|
+ float* o = out + seq * numHeads * headDim + head * headDim;
|
|
|
+
|
|
|
+ const int kvStride = numKVHeads * headDim;
|
|
|
+ const int queryPos = startPos + seq;
|
|
|
+ const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
|
|
|
+
|
|
|
+ // Cache Q in registers (per lane) to avoid reloading it for every KV token.
|
|
|
+ float qreg[8];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ qreg[i] = (d < headDim) ? q[d] : 0.0f;
|
|
|
+ }
|
|
|
+
|
|
|
+ float acc[8];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ acc[i] = 0.0f;
|
|
|
+ }
|
|
|
+ float m = -INFINITY;
|
|
|
+ float l = 0.0f;
|
|
|
+
|
|
|
+ for (int kv = 0; kv < effectiveLen; kv++) {
|
|
|
+ const int blockIdxKV = kv / blockSize;
|
|
|
+ const int blockOff = kv % blockSize;
|
|
|
+
|
|
|
+ const __half* kBlock = reinterpret_cast<const __half*>(KBlocks[blockIdxKV]);
|
|
|
+ const __half* k = kBlock + blockOff * kvStride + kvHead * headDim;
|
|
|
+
|
|
|
+ float dot = 0.0f;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ dot = fmaf(qreg[i], __half2float(k[d]), dot);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (int offset = 16; offset > 0; offset >>= 1) {
|
|
|
+ dot += __shfl_down_sync(0xffffffff, dot, offset);
|
|
|
+ }
|
|
|
+ dot = __shfl_sync(0xffffffff, dot, 0);
|
|
|
+
|
|
|
+ const float score = dot * scale;
|
|
|
+ float alpha = 1.0f;
|
|
|
+ float beta = 0.0f;
|
|
|
+ if (lane == 0) {
|
|
|
+ const float newM = fmaxf(m, score);
|
|
|
+ alpha = __expf(m - newM);
|
|
|
+ beta = __expf(score - newM);
|
|
|
+ m = newM;
|
|
|
+ l = l * alpha + beta;
|
|
|
+ }
|
|
|
+ alpha = __shfl_sync(0xffffffff, alpha, 0);
|
|
|
+ beta = __shfl_sync(0xffffffff, beta, 0);
|
|
|
+ l = __shfl_sync(0xffffffff, l, 0);
|
|
|
+
|
|
|
+ const __half* vBlock = reinterpret_cast<const __half*>(VBlocks[blockIdxKV]);
|
|
|
+ const __half* v = vBlock + blockOff * kvStride + kvHead * headDim;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ o[d] = acc[i] * invL;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// Fused RoPE + paged attention (single, f16 KV). Expects un-rotated Q/K.
|
|
|
+__global__ void paged_attention_kernel_f16kv_rope(
|
|
|
+ const float* Q,
|
|
|
+ const unsigned short* const* KBlocks,
|
|
|
+ const unsigned short* const* VBlocks,
|
|
|
+ float* out,
|
|
|
+ int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale, int startPos
|
|
|
+) {
|
|
|
+ const int seq = blockIdx.x;
|
|
|
+ const int head = blockIdx.y;
|
|
|
+ const int lane = threadIdx.x & 31;
|
|
|
+ if (seq >= seqLen) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // One warp per (seq, head)
|
|
|
+ if (threadIdx.x >= 32) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int kvHead = head / (numHeads / numKVHeads);
|
|
|
+ const float* q = Q + seq * numHeads * headDim + head * headDim;
|
|
|
+ float* o = out + seq * numHeads * headDim + head * headDim;
|
|
|
+
|
|
|
+ const int kvStride = numKVHeads * headDim;
|
|
|
+ const int queryPos = startPos + seq;
|
|
|
+ const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
|
|
|
+ const int halfDim = headDim >> 1;
|
|
|
+
|
|
|
+ float q0[4];
|
|
|
+ float q1[4];
|
|
|
+ float cosStep[4];
|
|
|
+ float sinStep[4];
|
|
|
+ float cosDelta[4];
|
|
|
+ float sinDelta[4];
|
|
|
+ int pairCount = 0;
|
|
|
+ for (int j = lane; j < halfDim; j += 32) {
|
|
|
+ q0[pairCount] = q[j];
|
|
|
+ q1[pairCount] = q[j + halfDim];
|
|
|
+ cosStep[pairCount] = rope_cos_step_const[j];
|
|
|
+ sinStep[pairCount] = rope_sin_step_const[j];
|
|
|
+ const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
|
|
|
+ const float2 ph = complex_pow_int(baseStep, queryPos);
|
|
|
+ cosDelta[pairCount] = ph.x;
|
|
|
+ sinDelta[pairCount] = ph.y;
|
|
|
+ pairCount++;
|
|
|
+ }
|
|
|
+
|
|
|
+ float acc[8];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ acc[i] = 0.0f;
|
|
|
+ }
|
|
|
+ float m = -INFINITY;
|
|
|
+ float l = 0.0f;
|
|
|
+
|
|
|
+ for (int kv = 0; kv < effectiveLen; kv++) {
|
|
|
+ const int blockIdxKV = kv / blockSize;
|
|
|
+ const int blockOff = kv % blockSize;
|
|
|
+
|
|
|
+ const __half* kBlock = reinterpret_cast<const __half*>(KBlocks[blockIdxKV]);
|
|
|
+ const __half* k = kBlock + blockOff * kvStride + kvHead * headDim;
|
|
|
+
|
|
|
+ float dot = 0.0f;
|
|
|
+ #pragma unroll
|
|
|
+ for (int pi = 0; pi < 4; pi++) {
|
|
|
+ if (pi >= pairCount) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ const int j = lane + 32 * pi;
|
|
|
+ if (j >= halfDim) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ const float k0 = __half2float(k[j]);
|
|
|
+ const float k1 = __half2float(k[j + halfDim]);
|
|
|
+ const float a = fmaf(q0[pi], k0, q1[pi] * k1);
|
|
|
+ const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
|
|
|
+ dot = fmaf(cosDelta[pi], a, dot);
|
|
|
+ dot = fmaf(sinDelta[pi], b, dot);
|
|
|
+ }
|
|
|
+ for (int offset = 16; offset > 0; offset >>= 1) {
|
|
|
+ dot += __shfl_down_sync(0xffffffff, dot, offset);
|
|
|
+ }
|
|
|
+ dot = __shfl_sync(0xffffffff, dot, 0);
|
|
|
+
|
|
|
+ // Advance delta -> delta-1 for next kv.
|
|
|
+ #pragma unroll
|
|
|
+ for (int pi = 0; pi < 4; pi++) {
|
|
|
+ if (pi >= pairCount) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
|
|
|
+ }
|
|
|
+
|
|
|
+ const float score = dot * scale;
|
|
|
+ float alpha = 1.0f;
|
|
|
+ float beta = 0.0f;
|
|
|
+ if (lane == 0) {
|
|
|
+ const float newM = fmaxf(m, score);
|
|
|
+ alpha = __expf(m - newM);
|
|
|
+ beta = __expf(score - newM);
|
|
|
+ m = newM;
|
|
|
+ l = l * alpha + beta;
|
|
|
+ }
|
|
|
+ alpha = __shfl_sync(0xffffffff, alpha, 0);
|
|
|
+ beta = __shfl_sync(0xffffffff, beta, 0);
|
|
|
+ l = __shfl_sync(0xffffffff, l, 0);
|
|
|
+
|
|
|
+ const __half* vBlock = reinterpret_cast<const __half*>(VBlocks[blockIdxKV]);
|
|
|
+ const __half* v = vBlock + blockOff * kvStride + kvHead * headDim;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ acc[i] = acc[i] * alpha + beta * __half2float(v[d]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ o[d] = acc[i] * invL;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+template <typename T>
|
|
|
+__device__ __forceinline__ float load_kv(const T* p, int idx) {
|
|
|
+ return p[idx];
|
|
|
+}
|
|
|
+
|
|
|
+template <>
|
|
|
+__device__ __forceinline__ float load_kv<__half>(const __half* p, int idx) {
|
|
|
+ return __half2float(p[idx]);
|
|
|
+}
|
|
|
+
|
|
|
+template <typename T, bool kUseRoPE = false>
|
|
|
+__global__ void paged_attention_split_kv_kernel(
|
|
|
+ const float* Q,
|
|
|
+ const T* const* KBlocks,
|
|
|
+ const T* const* VBlocks,
|
|
|
+ float* partialMax,
|
|
|
+ float* partialSum,
|
|
|
+ float* partialOut,
|
|
|
+ int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale, int startPos,
|
|
|
+ int numSplits, int splitSize
|
|
|
+) {
|
|
|
+ const int seq = blockIdx.x;
|
|
|
+ const int head = blockIdx.y;
|
|
|
+ const int split = blockIdx.z;
|
|
|
+ const int lane = threadIdx.x & 31;
|
|
|
+ if (seq >= seqLen) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // One warp per (seq, head, split).
|
|
|
+ if (threadIdx.x >= 32) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int kvHead = head / (numHeads / numKVHeads);
|
|
|
+ const float* q = Q + seq * numHeads * headDim + head * headDim;
|
|
|
+
|
|
|
+ const int kvStride = numKVHeads * headDim;
|
|
|
+ const int queryPos = startPos + seq;
|
|
|
+ const int effectiveLen = (kvLen < (queryPos + 1)) ? kvLen : (queryPos + 1);
|
|
|
+
|
|
|
+ const int splitStart = split * splitSize;
|
|
|
+ const int splitEnd = (splitStart + splitSize < effectiveLen) ? (splitStart + splitSize) : effectiveLen;
|
|
|
+
|
|
|
+ const size_t qh = (size_t)seq * (size_t)numHeads + (size_t)head;
|
|
|
+ const size_t splitIdx = qh * (size_t)numSplits + (size_t)split;
|
|
|
+ float* outVec = partialOut + splitIdx * (size_t)headDim;
|
|
|
+
|
|
|
+ if (splitStart >= splitEnd) {
|
|
|
+ if (lane == 0) {
|
|
|
+ partialMax[splitIdx] = -INFINITY;
|
|
|
+ partialSum[splitIdx] = 0.0f;
|
|
|
+ }
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ outVec[d] = 0.0f;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int halfDim = headDim >> 1;
|
|
|
+ const int ropeExp = queryPos - splitStart;
|
|
|
+
|
|
|
+ float qreg[8];
|
|
|
+ float q0[4];
|
|
|
+ float q1[4];
|
|
|
+ float cosStep[4];
|
|
|
+ float sinStep[4];
|
|
|
+ float cosDelta[4];
|
|
|
+ float sinDelta[4];
|
|
|
+ int pairCount = 0;
|
|
|
+ if (kUseRoPE) {
|
|
|
+ for (int j = lane; j < halfDim; j += 32) {
|
|
|
+ q0[pairCount] = q[j];
|
|
|
+ q1[pairCount] = q[j + halfDim];
|
|
|
+ cosStep[pairCount] = rope_cos_step_const[j];
|
|
|
+ sinStep[pairCount] = rope_sin_step_const[j];
|
|
|
+ const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
|
|
|
+ const float2 ph = complex_pow_int(baseStep, ropeExp);
|
|
|
+ cosDelta[pairCount] = ph.x;
|
|
|
+ sinDelta[pairCount] = ph.y;
|
|
|
+ pairCount++;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ qreg[i] = (d < headDim) ? q[d] : 0.0f;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ float acc[8];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ acc[i] = 0.0f;
|
|
|
+ }
|
|
|
+ float m = -INFINITY;
|
|
|
+ float l = 0.0f;
|
|
|
+
|
|
|
+ for (int kv = splitStart; kv < splitEnd; kv++) {
|
|
|
+ const int blockIdxKV = kv / blockSize;
|
|
|
+ const int blockOff = kv % blockSize;
|
|
|
+
|
|
|
+ const T* kBlock = KBlocks[blockIdxKV];
|
|
|
+ const T* k = kBlock + blockOff * kvStride + kvHead * headDim;
|
|
|
+
|
|
|
+ float dot = 0.0f;
|
|
|
+ if (kUseRoPE) {
|
|
|
+ #pragma unroll
|
|
|
+ for (int pi = 0; pi < 4; pi++) {
|
|
|
+ if (pi >= pairCount) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ const int j = lane + 32 * pi;
|
|
|
+ if (j >= halfDim) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ const float k0 = load_kv<T>(k, j);
|
|
|
+ const float k1 = load_kv<T>(k, j + halfDim);
|
|
|
+ const float a = fmaf(q0[pi], k0, q1[pi] * k1);
|
|
|
+ const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
|
|
|
+ dot = fmaf(cosDelta[pi], a, dot);
|
|
|
+ dot = fmaf(sinDelta[pi], b, dot);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ dot = fmaf(qreg[i], load_kv<T>(k, d), dot);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (int offset = 16; offset > 0; offset >>= 1) {
|
|
|
+ dot += __shfl_down_sync(0xffffffff, dot, offset);
|
|
|
+ }
|
|
|
+ dot = __shfl_sync(0xffffffff, dot, 0);
|
|
|
+
|
|
|
+ if (kUseRoPE) {
|
|
|
+ #pragma unroll
|
|
|
+ for (int pi = 0; pi < 4; pi++) {
|
|
|
+ if (pi >= pairCount) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const float score = dot * scale;
|
|
|
+ float alpha = 1.0f;
|
|
|
+ float beta = 0.0f;
|
|
|
+ if (lane == 0) {
|
|
|
+ const float newM = fmaxf(m, score);
|
|
|
+ alpha = __expf(m - newM);
|
|
|
+ beta = __expf(score - newM);
|
|
|
+ m = newM;
|
|
|
+ l = l * alpha + beta;
|
|
|
+ }
|
|
|
+ alpha = __shfl_sync(0xffffffff, alpha, 0);
|
|
|
+ beta = __shfl_sync(0xffffffff, beta, 0);
|
|
|
+
|
|
|
+ const T* vBlock = VBlocks[blockIdxKV];
|
|
|
+ const T* v = vBlock + blockOff * kvStride + kvHead * headDim;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ acc[i] = acc[i] * alpha + beta * load_kv<T>(v, d);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if (lane == 0) {
|
|
|
+ partialMax[splitIdx] = m;
|
|
|
+ partialSum[splitIdx] = l;
|
|
|
+ }
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ outVec[d] = acc[i];
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+template <typename T, bool kUseRoPE = false>
|
|
|
+__global__ void paged_attention_split_kv_batch_kernel(
|
|
|
+ const float* Q,
|
|
|
+ const T* const* KBlocksFlat,
|
|
|
+ const T* const* VBlocksFlat,
|
|
|
+ const int* blockOffsets,
|
|
|
+ const int* kvLens,
|
|
|
+ const int* queryPos,
|
|
|
+ float* partialMax,
|
|
|
+ float* partialSum,
|
|
|
+ float* partialOut,
|
|
|
+ int numTokens,
|
|
|
+ int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale,
|
|
|
+ int numSplits, int splitSize
|
|
|
+) {
|
|
|
+ const int tok = blockIdx.x;
|
|
|
+ const int head = blockIdx.y;
|
|
|
+ const int split = blockIdx.z;
|
|
|
+ const int lane = threadIdx.x & 31;
|
|
|
+ if (tok >= numTokens) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // One warp per (tok, head, split).
|
|
|
+ if (threadIdx.x >= 32) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int kvHead = head / (numHeads / numKVHeads);
|
|
|
+ const float* q = Q + tok * numHeads * headDim + head * headDim;
|
|
|
+
|
|
|
+ const int kvLen = kvLens[tok];
|
|
|
+ const int qPos = queryPos[tok];
|
|
|
+ const int base = blockOffsets[tok];
|
|
|
+
|
|
|
+ const int kvStride = numKVHeads * headDim;
|
|
|
+ const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
|
|
|
+
|
|
|
+ const int splitStart = split * splitSize;
|
|
|
+ const int splitEnd = (splitStart + splitSize < effectiveLen) ? (splitStart + splitSize) : effectiveLen;
|
|
|
+
|
|
|
+ const size_t qh = (size_t)tok * (size_t)numHeads + (size_t)head;
|
|
|
+ const size_t splitIdx = qh * (size_t)numSplits + (size_t)split;
|
|
|
+ float* outVec = partialOut + splitIdx * (size_t)headDim;
|
|
|
+
|
|
|
+ if (splitStart >= splitEnd) {
|
|
|
+ if (lane == 0) {
|
|
|
+ partialMax[splitIdx] = -INFINITY;
|
|
|
+ partialSum[splitIdx] = 0.0f;
|
|
|
+ }
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ outVec[d] = 0.0f;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int halfDim = headDim >> 1;
|
|
|
+ const int ropeExp = qPos - splitStart;
|
|
|
+
|
|
|
+ float qreg[8];
|
|
|
+ float q0[4];
|
|
|
+ float q1[4];
|
|
|
+ float cosStep[4];
|
|
|
+ float sinStep[4];
|
|
|
+ float cosDelta[4];
|
|
|
+ float sinDelta[4];
|
|
|
+ int pairCount = 0;
|
|
|
+ if (kUseRoPE) {
|
|
|
+ for (int j = lane; j < halfDim; j += 32) {
|
|
|
+ q0[pairCount] = q[j];
|
|
|
+ q1[pairCount] = q[j + halfDim];
|
|
|
+ cosStep[pairCount] = rope_cos_step_const[j];
|
|
|
+ sinStep[pairCount] = rope_sin_step_const[j];
|
|
|
+ const float2 baseStep = make_float2(cosStep[pairCount], sinStep[pairCount]);
|
|
|
+ const float2 ph = complex_pow_int(baseStep, ropeExp);
|
|
|
+ cosDelta[pairCount] = ph.x;
|
|
|
+ sinDelta[pairCount] = ph.y;
|
|
|
+ pairCount++;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ qreg[i] = (d < headDim) ? q[d] : 0.0f;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ float acc[8];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ acc[i] = 0.0f;
|
|
|
+ }
|
|
|
+ float m = -INFINITY;
|
|
|
+ float l = 0.0f;
|
|
|
+
|
|
|
+ for (int kv = splitStart; kv < splitEnd; kv++) {
|
|
|
+ const int bidx = kv / blockSize;
|
|
|
+ const int boff = kv % blockSize;
|
|
|
+
|
|
|
+ const T* kBlock = KBlocksFlat[base + bidx];
|
|
|
+ const T* k = kBlock + boff * kvStride + kvHead * headDim;
|
|
|
+
|
|
|
+ float dot = 0.0f;
|
|
|
+ if (kUseRoPE) {
|
|
|
+ #pragma unroll
|
|
|
+ for (int pi = 0; pi < 4; pi++) {
|
|
|
+ if (pi >= pairCount) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ const int j = lane + 32 * pi;
|
|
|
+ if (j >= halfDim) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ const float k0 = load_kv<T>(k, j);
|
|
|
+ const float k1 = load_kv<T>(k, j + halfDim);
|
|
|
+ const float a = fmaf(q0[pi], k0, q1[pi] * k1);
|
|
|
+ const float b = fmaf(q0[pi], k1, -q1[pi] * k0);
|
|
|
+ dot = fmaf(cosDelta[pi], a, dot);
|
|
|
+ dot = fmaf(sinDelta[pi], b, dot);
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ dot = fmaf(qreg[i], load_kv<T>(k, d), dot);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ for (int offset = 16; offset > 0; offset >>= 1) {
|
|
|
+ dot += __shfl_down_sync(0xffffffff, dot, offset);
|
|
|
+ }
|
|
|
+ dot = __shfl_sync(0xffffffff, dot, 0);
|
|
|
+
|
|
|
+ if (kUseRoPE) {
|
|
|
+ #pragma unroll
|
|
|
+ for (int pi = 0; pi < 4; pi++) {
|
|
|
+ if (pi >= pairCount) {
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ rope_advance_neg(cosDelta[pi], sinDelta[pi], cosStep[pi], sinStep[pi]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const float score = dot * scale;
|
|
|
+ float alpha = 1.0f;
|
|
|
+ float beta = 0.0f;
|
|
|
+ if (lane == 0) {
|
|
|
+ const float newM = fmaxf(m, score);
|
|
|
+ alpha = __expf(m - newM);
|
|
|
+ beta = __expf(score - newM);
|
|
|
+ m = newM;
|
|
|
+ l = l * alpha + beta;
|
|
|
+ }
|
|
|
+ alpha = __shfl_sync(0xffffffff, alpha, 0);
|
|
|
+ beta = __shfl_sync(0xffffffff, beta, 0);
|
|
|
+
|
|
|
+ const T* vBlock = VBlocksFlat[base + bidx];
|
|
|
+ const T* v = vBlock + boff * kvStride + kvHead * headDim;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ acc[i] = acc[i] * alpha + beta * load_kv<T>(v, d);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if (lane == 0) {
|
|
|
+ partialMax[splitIdx] = m;
|
|
|
+ partialSum[splitIdx] = l;
|
|
|
+ }
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ outVec[d] = acc[i];
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+__global__ void paged_attention_split_kv_reduce_kernel(
|
|
|
+ const float* partialMax,
|
|
|
+ const float* partialSum,
|
|
|
+ const float* partialOut,
|
|
|
+ float* out,
|
|
|
+ int queryCount,
|
|
|
+ int numHeads,
|
|
|
+ int headDim,
|
|
|
+ int numSplits
|
|
|
+) {
|
|
|
+ const int q = blockIdx.x;
|
|
|
+ const int head = blockIdx.y;
|
|
|
+ const int lane = threadIdx.x & 31;
|
|
|
+ if (q >= queryCount) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // One warp per (q, head).
|
|
|
+ if (threadIdx.x >= 32) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ const size_t qh = (size_t)q * (size_t)numHeads + (size_t)head;
|
|
|
+ const size_t base = qh * (size_t)numSplits;
|
|
|
+
|
|
|
+ float gmax = -INFINITY;
|
|
|
+ if (lane == 0) {
|
|
|
+ for (int s = 0; s < numSplits; s++) {
|
|
|
+ gmax = fmaxf(gmax, partialMax[base + (size_t)s]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ gmax = __shfl_sync(0xffffffff, gmax, 0);
|
|
|
+ if (gmax == -INFINITY) {
|
|
|
+ const size_t outBase = qh * (size_t)headDim;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ out[outBase + (size_t)d] = 0.0f;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ float sum = 0.0f;
|
|
|
+ if (lane == 0) {
|
|
|
+ for (int s = 0; s < numSplits; s++) {
|
|
|
+ const float l = partialSum[base + (size_t)s];
|
|
|
+ sum += l * __expf(partialMax[base + (size_t)s] - gmax);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ sum = __shfl_sync(0xffffffff, sum, 0);
|
|
|
+ const float invSum = (sum > 0.0f) ? (1.0f / sum) : 0.0f;
|
|
|
+
|
|
|
+ float acc[8];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ acc[i] = 0.0f;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int s = 0; s < numSplits; s++) {
|
|
|
+ float scale = 0.0f;
|
|
|
+ if (lane == 0) {
|
|
|
+ scale = __expf(partialMax[base + (size_t)s] - gmax);
|
|
|
+ }
|
|
|
+ scale = __shfl_sync(0xffffffff, scale, 0);
|
|
|
+
|
|
|
+ const float* vec = partialOut + (base + (size_t)s) * (size_t)headDim;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ acc[i] = fmaf(scale, vec[d], acc[i]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ const size_t outBase = qh * (size_t)headDim;
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < 8; i++) {
|
|
|
+ const int d = lane + 32 * i;
|
|
|
+ if (d < headDim) {
|
|
|
+ out[outBase + (size_t)d] = acc[i] * invSum;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+int cuda_rmsnorm_f32(float* x, const float* w, int seqLen, int dim, float eps) {
|
|
|
+ int threads = 256;
|
|
|
+ rmsnorm_kernel<<<seqLen, threads>>>(x, w, dim, eps);
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+int cuda_cast_f32_to_f16(const float* src, unsigned short* dst, int n) {
|
|
|
+ int threads = 256;
|
|
|
+ int blocks = (n + threads - 1) / threads;
|
|
|
+ cast_f32_to_f16_kernel<<<blocks, threads>>>(src, reinterpret_cast<__half*>(dst), n);
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+int cuda_paged_attention_f32_f16kv(
|
|
|
+ const float* Q,
|
|
|
+ const unsigned short* const* KBlocks,
|
|
|
+ const unsigned short* const* VBlocks,
|
|
|
+ float* out,
|
|
|
+ int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale, int startPos
|
|
|
+) {
|
|
|
+ // Split-KV Flash Decoding for long contexts.
|
|
|
+ const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
|
|
|
+ const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
|
|
|
+ const int qhCount = seqLen * numHeads;
|
|
|
+ const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
|
|
|
+
|
|
|
+ if (!useSplit) {
|
|
|
+ dim3 blocks(seqLen, numHeads);
|
|
|
+ int threads = 32;
|
|
|
+ paged_attention_kernel_f16kv<<<blocks, threads>>>(
|
|
|
+ Q, KBlocks, VBlocks, out,
|
|
|
+ seqLen, kvLen, numHeads, numKVHeads, headDim,
|
|
|
+ blockSize, scale, startPos
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
|
|
|
+ const size_t totalFloats = splitCount * (size_t)(headDim + 2);
|
|
|
+ float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
|
|
|
+ if (buf == nullptr) {
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+ float* partialMax = buf;
|
|
|
+ float* partialSum = partialMax + splitCount;
|
|
|
+ float* partialOut = partialSum + splitCount;
|
|
|
+
|
|
|
+ dim3 blocks1(seqLen, numHeads, numSplits);
|
|
|
+ paged_attention_split_kv_kernel<__half><<<blocks1, 32>>>(
|
|
|
+ Q,
|
|
|
+ reinterpret_cast<const __half* const*>(KBlocks),
|
|
|
+ reinterpret_cast<const __half* const*>(VBlocks),
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ seqLen, kvLen, numHeads, numKVHeads, headDim,
|
|
|
+ blockSize,
|
|
|
+ scale, startPos,
|
|
|
+ numSplits, kPagedAttentionSplitSize
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ dim3 blocks2(seqLen, numHeads);
|
|
|
+ paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ out,
|
|
|
+ seqLen, numHeads, headDim, numSplits
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ cuda_free(buf);
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+int cuda_paged_attention_rope_f32_f16kv(
|
|
|
+ const float* Q,
|
|
|
+ const unsigned short* const* KBlocks,
|
|
|
+ const unsigned short* const* VBlocks,
|
|
|
+ float* out,
|
|
|
+ int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale, int startPos,
|
|
|
+ float theta
|
|
|
+) {
|
|
|
+ if (ensure_rope_step_table(headDim, theta) != 0) {
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Split-KV Flash Decoding for long contexts.
|
|
|
+ const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
|
|
|
+ const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
|
|
|
+ const int qhCount = seqLen * numHeads;
|
|
|
+ const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
|
|
|
+
|
|
|
+ if (!useSplit) {
|
|
|
+ dim3 blocks(seqLen, numHeads);
|
|
|
+ int threads = 32;
|
|
|
+ paged_attention_kernel_f16kv_rope<<<blocks, threads>>>(
|
|
|
+ Q, KBlocks, VBlocks, out,
|
|
|
+ seqLen, kvLen, numHeads, numKVHeads, headDim,
|
|
|
+ blockSize, scale, startPos
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
|
|
|
+ const size_t totalFloats = splitCount * (size_t)(headDim + 2);
|
|
|
+ float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
|
|
|
+ if (buf == nullptr) {
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+ float* partialMax = buf;
|
|
|
+ float* partialSum = partialMax + splitCount;
|
|
|
+ float* partialOut = partialSum + splitCount;
|
|
|
+
|
|
|
+ dim3 blocks1(seqLen, numHeads, numSplits);
|
|
|
+ paged_attention_split_kv_kernel<__half, true><<<blocks1, 32>>>(
|
|
|
+ Q,
|
|
|
+ reinterpret_cast<const __half* const*>(KBlocks),
|
|
|
+ reinterpret_cast<const __half* const*>(VBlocks),
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ seqLen, kvLen, numHeads, numKVHeads, headDim,
|
|
|
+ blockSize,
|
|
|
+ scale, startPos,
|
|
|
+ numSplits, kPagedAttentionSplitSize
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ dim3 blocks2(seqLen, numHeads);
|
|
|
+ paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ out,
|
|
|
+ seqLen, numHeads, headDim, numSplits
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ cuda_free(buf);
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+int cuda_paged_attention_batch_f32_f16kv(
|
|
|
+ const float* Q,
|
|
|
+ const unsigned short* const* KBlocksFlat,
|
|
|
+ const unsigned short* const* VBlocksFlat,
|
|
|
+ const int* blockOffsets,
|
|
|
+ const int* kvLens,
|
|
|
+ const int* queryPos,
|
|
|
+ float* out,
|
|
|
+ int numTokens,
|
|
|
+ int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale,
|
|
|
+ int maxKvLen
|
|
|
+) {
|
|
|
+ const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
|
|
|
+ const int qhCount = numTokens * numHeads;
|
|
|
+ const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
|
|
|
+
|
|
|
+ if (!useSplit) {
|
|
|
+ dim3 blocks(numTokens, numHeads);
|
|
|
+ int threads = 32;
|
|
|
+ paged_attention_batch_kernel_f16kv<<<blocks, threads>>>(
|
|
|
+ Q, KBlocksFlat, VBlocksFlat,
|
|
|
+ blockOffsets, kvLens, queryPos,
|
|
|
+ out,
|
|
|
+ numTokens,
|
|
|
+ numHeads, numKVHeads, headDim,
|
|
|
+ blockSize,
|
|
|
+ scale
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
|
|
|
+ const size_t totalFloats = splitCount * (size_t)(headDim + 2);
|
|
|
+ float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
|
|
|
+ if (buf == nullptr) {
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+ float* partialMax = buf;
|
|
|
+ float* partialSum = partialMax + splitCount;
|
|
|
+ float* partialOut = partialSum + splitCount;
|
|
|
+
|
|
|
+ dim3 blocks1(numTokens, numHeads, numSplits);
|
|
|
+ paged_attention_split_kv_batch_kernel<__half><<<blocks1, 32>>>(
|
|
|
+ Q,
|
|
|
+ reinterpret_cast<const __half* const*>(KBlocksFlat),
|
|
|
+ reinterpret_cast<const __half* const*>(VBlocksFlat),
|
|
|
+ blockOffsets, kvLens, queryPos,
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ numTokens,
|
|
|
+ numHeads, numKVHeads, headDim,
|
|
|
+ blockSize,
|
|
|
+ scale,
|
|
|
+ numSplits, kPagedAttentionSplitSize
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ dim3 blocks2(numTokens, numHeads);
|
|
|
+ paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ out,
|
|
|
+ numTokens, numHeads, headDim, numSplits
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ cuda_free(buf);
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+int cuda_paged_attention_rope_batch_f32_f16kv(
|
|
|
+ const float* Q,
|
|
|
+ const unsigned short* const* KBlocksFlat,
|
|
|
+ const unsigned short* const* VBlocksFlat,
|
|
|
+ const int* blockOffsets,
|
|
|
+ const int* kvLens,
|
|
|
+ const int* queryPos,
|
|
|
+ float* out,
|
|
|
+ int numTokens,
|
|
|
+ int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale,
|
|
|
+ int maxKvLen,
|
|
|
+ float theta
|
|
|
+) {
|
|
|
+ if (ensure_rope_step_table(headDim, theta) != 0) {
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+
|
|
|
+ const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
|
|
|
+ const int qhCount = numTokens * numHeads;
|
|
|
+ const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
|
|
|
+
|
|
|
+ if (!useSplit) {
|
|
|
+ dim3 blocks(numTokens, numHeads);
|
|
|
+ int threads = 32;
|
|
|
+ paged_attention_batch_kernel_f16kv_rope<<<blocks, threads>>>(
|
|
|
+ Q, KBlocksFlat, VBlocksFlat,
|
|
|
+ blockOffsets, kvLens, queryPos,
|
|
|
+ out,
|
|
|
+ numTokens,
|
|
|
+ numHeads, numKVHeads, headDim,
|
|
|
+ blockSize,
|
|
|
+ scale
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
|
|
|
+ const size_t totalFloats = splitCount * (size_t)(headDim + 2);
|
|
|
+ float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
|
|
|
+ if (buf == nullptr) {
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+ float* partialMax = buf;
|
|
|
+ float* partialSum = partialMax + splitCount;
|
|
|
+ float* partialOut = partialSum + splitCount;
|
|
|
+
|
|
|
+ dim3 blocks1(numTokens, numHeads, numSplits);
|
|
|
+ paged_attention_split_kv_batch_kernel<__half, true><<<blocks1, 32>>>(
|
|
|
+ Q,
|
|
|
+ reinterpret_cast<const __half* const*>(KBlocksFlat),
|
|
|
+ reinterpret_cast<const __half* const*>(VBlocksFlat),
|
|
|
+ blockOffsets, kvLens, queryPos,
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ numTokens,
|
|
|
+ numHeads, numKVHeads, headDim,
|
|
|
+ blockSize,
|
|
|
+ scale,
|
|
|
+ numSplits, kPagedAttentionSplitSize
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ dim3 blocks2(numTokens, numHeads);
|
|
|
+ paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ out,
|
|
|
+ numTokens, numHeads, headDim, numSplits
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ cuda_free(buf);
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+__global__ void paged_attention_batch_kernel(
|
|
|
+ const float* Q,
|
|
|
+ const float* const* KBlocksFlat,
|
|
|
+ const float* const* VBlocksFlat,
|
|
|
+ const int* blockOffsets,
|
|
|
+ const int* kvLens,
|
|
|
+ const int* queryPos,
|
|
|
+ float* out,
|
|
|
+ int numTokens,
|
|
|
+ int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale
|
|
|
+);
|
|
|
+
|
|
|
+__global__ void paged_attention_batch_kernel_f16kv(
|
|
|
+ const float* Q,
|
|
|
+ const unsigned short* const* KBlocksFlat,
|
|
|
+ const unsigned short* const* VBlocksFlat,
|
|
|
+ const int* blockOffsets,
|
|
|
+ const int* kvLens,
|
|
|
+ const int* queryPos,
|
|
|
+ float* out,
|
|
|
+ int numTokens,
|
|
|
+ int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale
|
|
|
+);
|
|
|
+
|
|
|
+int cuda_paged_attention_batch_f32(
|
|
|
+ const float* Q,
|
|
|
+ const float* const* KBlocksFlat,
|
|
|
+ const float* const* VBlocksFlat,
|
|
|
+ const int* blockOffsets,
|
|
|
+ const int* kvLens,
|
|
|
+ const int* queryPos,
|
|
|
+ float* out,
|
|
|
+ int numTokens,
|
|
|
+ int numHeads, int numKVHeads, int headDim,
|
|
|
+ int blockSize,
|
|
|
+ float scale,
|
|
|
+ int maxKvLen
|
|
|
+) {
|
|
|
+ const int numSplits = (maxKvLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
|
|
|
+ const int qhCount = numTokens * numHeads;
|
|
|
+ const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
|
|
|
+
|
|
|
+ if (!useSplit) {
|
|
|
+ dim3 blocks(numTokens, numHeads);
|
|
|
+ int threads = 256;
|
|
|
+ paged_attention_batch_kernel<<<blocks, threads>>>(
|
|
|
+ Q, KBlocksFlat, VBlocksFlat,
|
|
|
+ blockOffsets, kvLens, queryPos,
|
|
|
+ out,
|
|
|
+ numTokens,
|
|
|
+ numHeads, numKVHeads, headDim,
|
|
|
+ blockSize,
|
|
|
+ scale
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+ }
|
|
|
+
|
|
|
+ const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
|
|
|
+ const size_t totalFloats = splitCount * (size_t)(headDim + 2);
|
|
|
+ float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
|
|
|
+ if (buf == nullptr) {
|
|
|
+ return 1;
|
|
|
+ }
|
|
|
+ float* partialMax = buf;
|
|
|
+ float* partialSum = partialMax + splitCount;
|
|
|
+ float* partialOut = partialSum + splitCount;
|
|
|
+
|
|
|
+ dim3 blocks1(numTokens, numHeads, numSplits);
|
|
|
+ paged_attention_split_kv_batch_kernel<float><<<blocks1, 32>>>(
|
|
|
+ Q,
|
|
|
+ KBlocksFlat,
|
|
|
+ VBlocksFlat,
|
|
|
+ blockOffsets, kvLens, queryPos,
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ numTokens,
|
|
|
+ numHeads, numKVHeads, headDim,
|
|
|
+ blockSize,
|
|
|
+ scale,
|
|
|
+ numSplits, kPagedAttentionSplitSize
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ dim3 blocks2(numTokens, numHeads);
|
|
|
+ paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
|
|
|
+ partialMax, partialSum, partialOut,
|
|
|
+ out,
|
|
|
+ numTokens, numHeads, headDim, numSplits
|
|
|
+ );
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+
|
|
|
+ cuda_free(buf);
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+// RoPE kernel
|
|
|
+__global__ void rope_kernel(float* x, const int* positions, int totalDim, int headDim, float theta) {
|
|
|
+ int seq = blockIdx.x;
|
|
|
+ int pos = positions[seq];
|
|
|
+ float* rowData = x + seq * totalDim;
|
|
|
+ int halfDim = headDim / 2;
|
|
|
+
|
|
|
+ // Each thread handles one (j, j+halfDim) pair across all heads.
|
|
|
+ // Compute sin/cos once per j and reuse across heads.
|
|
|
+ for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
|
|
|
+ const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
|
|
|
+ const float freq = pos * invFreq;
|
|
|
+ float sinF, cosF;
|
|
|
+ sincosf(freq, &sinF, &cosF);
|
|
|
+
|
|
|
+ for (int headStart = 0; headStart < totalDim; headStart += headDim) {
|
|
|
+ const int idx0 = headStart + j;
|
|
|
+ const int idx1 = idx0 + halfDim;
|
|
|
+
|
|
|
+ const float v0 = rowData[idx0];
|
|
|
+ const float v1 = rowData[idx1];
|
|
|
+
|
|
|
+ rowData[idx0] = v0 * cosF - v1 * sinF;
|
|
|
+ rowData[idx1] = v1 * cosF + v0 * sinF;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+int cuda_rope_f32(float* x, const int* positions, int seqLen, int numHeads, int headDim, float theta) {
|
|
|
+ int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
|
|
|
+ rope_kernel<<<seqLen, threads>>>(x, positions, numHeads * headDim, headDim, theta);
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+// Optimized RoPE kernel for single position (scalar pos)
|
|
|
+__global__ void rope_kernel_single(float* x, int pos, int totalDim, int headDim, float theta) {
|
|
|
+ // seq is always 0 for single token
|
|
|
+ float* rowData = x; // x + 0 * totalDim
|
|
|
+ int halfDim = headDim / 2;
|
|
|
+
|
|
|
+ // Each thread handles one (j, j+halfDim) pair across all heads.
|
|
|
+ // Compute sin/cos once per j and reuse across heads.
|
|
|
+ for (int j = threadIdx.x; j < halfDim; j += blockDim.x) {
|
|
|
+ const float invFreq = 1.0f / powf(theta, 2.0f * j / headDim);
|
|
|
+ const float freq = pos * invFreq;
|
|
|
+ float sinF, cosF;
|
|
|
+ sincosf(freq, &sinF, &cosF);
|
|
|
+
|
|
|
+ for (int headStart = 0; headStart < totalDim; headStart += headDim) {
|
|
|
+ const int idx0 = headStart + j;
|
|
|
+ const int idx1 = idx0 + halfDim;
|
|
|
+
|
|
|
+ const float v0 = rowData[idx0];
|
|
|
+ const float v1 = rowData[idx1];
|
|
|
+
|
|
|
+ rowData[idx0] = v0 * cosF - v1 * sinF;
|
|
|
+ rowData[idx1] = v1 * cosF + v0 * sinF;
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+int cuda_rope_f32_single(float* x, int pos, int numHeads, int headDim, float theta) {
|
|
|
+ int threads = (headDim / 2) < 128 ? (headDim / 2) : 128;
|
|
|
+ rope_kernel_single<<<1, threads>>>(x, pos, numHeads * headDim, headDim, theta);
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+// Softmax kernel: one block per row
|
|
|
+__global__ void softmax_kernel(float* x, int cols) {
|
|
|
+ int row = blockIdx.x;
|
|
|
+ float* rowData = x + row * cols;
|
|
|
+
|
|
|
+ __shared__ float smax[256];
|
|
|
+ __shared__ float ssum[256];
|
|
|
+
|
|
|
+ // Find max
|
|
|
+ float threadMax = -INFINITY;
|
|
|
+ for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
|
|
+ threadMax = fmaxf(threadMax, rowData[i]);
|
|
|
+ }
|
|
|
+ smax[threadIdx.x] = threadMax;
|
|
|
+ __syncthreads();
|
|
|
+
|
|
|
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
|
|
|
+ if (threadIdx.x < s) {
|
|
|
+ smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
|
|
|
+ }
|
|
|
+ __syncthreads();
|
|
|
+ }
|
|
|
+ float maxVal = smax[0];
|
|
|
+
|
|
|
+ // Compute exp and sum
|
|
|
+ float threadSum = 0.0f;
|
|
|
+ for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
|
|
+ float val = expf(rowData[i] - maxVal);
|
|
|
+ rowData[i] = val;
|
|
|
+ threadSum += val;
|
|
|
+ }
|
|
|
+ ssum[threadIdx.x] = threadSum;
|
|
|
+ __syncthreads();
|
|
|
+
|
|
|
+ for (int s = blockDim.x / 2; s > 0; s >>= 1) {
|
|
|
+ if (threadIdx.x < s) {
|
|
|
+ ssum[threadIdx.x] += ssum[threadIdx.x + s];
|
|
|
+ }
|
|
|
+ __syncthreads();
|
|
|
+ }
|
|
|
+ float sum = ssum[0];
|
|
|
+
|
|
|
+ // Normalize
|
|
|
+ for (int i = threadIdx.x; i < cols; i += blockDim.x) {
|
|
|
+ rowData[i] /= sum;
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+int cuda_softmax_f32(float* x, int rows, int cols) {
|
|
|
+ int threads = 256;
|
|
|
+ softmax_kernel<<<rows, threads>>>(x, cols);
|
|
|
+ CHECK_CUDA(cudaGetLastError());
|
|
|
+ return 0;
|
|
|
+}
|
|
|
+
|
|
|
+// ============================================================
|
|
|
+// Top-K Logits Selection (for sampling without full D2H)
|
|
|
+// ============================================================
|
|
|
+
|
|
|
+#define TOPK_THREADS 256
|
|
|
+#define TOPK_LOCAL 8
|
|
|
+#define TOPK_SEGMENT (TOPK_THREADS * TOPK_LOCAL) // 2048
|
|
|
+#define TOPK_MAX_K 64
|
|
|
+
|
|
|
+static __device__ __forceinline__ float apply_rep_penalty(float score, bool hit, float penalty) {
|
|
|
+ if (!hit) return score;
|
|
|
+ if (score > 0.0f) return score / penalty;
|
|
|
+ return score * penalty;
|
|
|
+}
|
|
|
+
|
|
|
+__global__ void topk_logits_kernel(
|
|
|
+ const float* logits, int vocab,
|
|
|
+ const int* rep_ids, int rep_count, float rep_penalty,
|
|
|
+ int k,
|
|
|
+ int* out_ids, float* out_scores
|
|
|
+) {
|
|
|
+ const int blockStart = blockIdx.x * TOPK_SEGMENT;
|
|
|
+ const int tid = threadIdx.x;
|
|
|
+ if (tid >= TOPK_THREADS) return;
|
|
|
+
|
|
|
+ float localScores[TOPK_LOCAL];
|
|
|
+ int localIds[TOPK_LOCAL];
|
|
|
+ #pragma unroll
|
|
|
+ for (int i = 0; i < TOPK_LOCAL; i++) {
|
|
|
+ localScores[i] = -INFINITY;
|
|
|
+ localIds[i] = -1;
|
|
|
+ }
|
|
|
+
|
|
|
+ // Each thread processes up to TOPK_LOCAL elements in a contiguous segment.
|
|
|
+ #pragma unroll
|
|
|
+ for (int j = 0; j < TOPK_LOCAL; j++) {
|
|
|
+ int idx = blockStart + tid + j * TOPK_THREADS;
|
|
|
+ if (idx >= vocab) break;
|
|
|
+ float score = logits[idx];
|
|
|
+ bool hit = false;
|
|
|
+ // rep_count is small (<=64). Linear scan is fine.
|
|
|
+ for (int r = 0; r < rep_count; r++) {
|
|
|
+ if (rep_ids[r] == idx) {
|
|
|
+ hit = true;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ score = apply_rep_penalty(score, hit, rep_penalty);
|
|
|
+
|
|
|
+ // Insert into local top list (descending)
|
|
|
+ int pos = TOPK_LOCAL;
|
|
|
+ for (int t = 0; t < TOPK_LOCAL; t++) {
|
|
|
+ if (score > localScores[t]) { pos = t; break; }
|
|
|
+ }
|
|
|
+ if (pos < TOPK_LOCAL) {
|
|
|
+ for (int t = TOPK_LOCAL - 1; t > pos; t--) {
|
|
|
+ localScores[t] = localScores[t-1];
|
|
|
+ localIds[t] = localIds[t-1];
|
|
|
+ }
|
|
|
+ localScores[pos] = score;
|
|
|
+ localIds[pos] = idx;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ __shared__ float shScores[TOPK_SEGMENT];
|
|
|
+ __shared__ int shIds[TOPK_SEGMENT];
|
|
|
+
|
|
|
+ // Write local results to shared candidate pool.
|
|
|
+ #pragma unroll
|
|
|
+ for (int j = 0; j < TOPK_LOCAL; j++) {
|
|
|
+ int out = tid * TOPK_LOCAL + j;
|
|
|
+ shScores[out] = localScores[j];
|
|
|
+ shIds[out] = localIds[j];
|
|
|
+ }
|
|
|
+ __syncthreads();
|
|
|
+
|
|
|
+ if (tid == 0) {
|
|
|
+ // Block-level exact top-k from TOPK_SEGMENT candidates.
|
|
|
+ if (k > TOPK_MAX_K) k = TOPK_MAX_K;
|
|
|
+ float bestScores[TOPK_MAX_K];
|
|
|
+ int bestIds[TOPK_MAX_K];
|
|
|
+ for (int i = 0; i < k; i++) {
|
|
|
+ bestScores[i] = -INFINITY;
|
|
|
+ bestIds[i] = -1;
|
|
|
+ }
|
|
|
+ for (int i = 0; i < TOPK_SEGMENT; i++) {
|
|
|
+ float score = shScores[i];
|
|
|
+ int id = shIds[i];
|
|
|
+ if (id < 0) continue;
|
|
|
+ // Insert into best (descending)
|
|
|
+ if (score <= bestScores[k-1]) continue;
|
|
|
+ int pos = k;
|
|
|
+ for (int t = 0; t < k; t++) {
|
|
|
+ if (score > bestScores[t]) { pos = t; break; }
|
|
|
+ }
|
|
|
+ if (pos < k) {
|
|
|
+ for (int t = k - 1; t > pos; t--) {
|
|
|
+ bestScores[t] = bestScores[t-1];
|
|
|
+ bestIds[t] = bestIds[t-1];
|
|
|
+ }
|
|
|
+ bestScores[pos] = score;
|
|
|
+ bestIds[pos] = id;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ int base = blockIdx.x * k;
|
|
|
+ for (int i = 0; i < k; i++) {
|
|
|
+ out_ids[base + i] = bestIds[i];
|
|
|
+ out_scores[base + i] = bestScores[i];
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
+
+int cuda_topk_logits_f32(
+    const float* logits, int vocab,
+    const int* rep_ids, int rep_count, float rep_penalty,
+    int k,
+    int* out_ids, float* out_scores
+) {
+    if (k <= 0) return 0;
+    if (k > TOPK_MAX_K) k = TOPK_MAX_K;
+    int blocks = (vocab + TOPK_SEGMENT - 1) / TOPK_SEGMENT;
+    dim3 grid(blocks);
+    dim3 block(TOPK_THREADS);
+    // Each block emits its own top-k, so out_ids/out_scores hold blocks * k
+    // candidates that the caller merges afterwards (see the sketch below).
+    topk_logits_kernel<<<grid, block>>>(logits, vocab, rep_ids, rep_count, rep_penalty, k, out_ids, out_scores);
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
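+// cuda_topk_logits_f32 leaves blocks * k candidates in out_ids / out_scores
+// (k per block; invalid slots hold id -1 and score -INF and sort to the back).
+// A caller still has to merge them into the global top-k. Minimal host-side
+// sketch, illustrative only and not part of this file's API; it assumes the
+// two arrays have already been copied back to (or are visible from) the host:
+/*
+#include <algorithm>
+#include <numeric>
+#include <vector>
+
+static std::vector<int> merge_topk_host(const int* ids, const float* scores,
+                                        int blocks, int k) {
+    const int n = blocks * k;
+    std::vector<int> order(n);
+    std::iota(order.begin(), order.end(), 0);
+    // Sort the k best candidate slots by score, descending.
+    std::partial_sort(order.begin(), order.begin() + k, order.end(),
+                      [&](int a, int b) { return scores[a] > scores[b]; });
+    std::vector<int> top(k);
+    for (int i = 0; i < k; i++) top[i] = ids[order[i]];
+    return top;
+}
+*/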
+
+// ============================================================
+// Attention Kernel
+// Computes: softmax(Q @ K.T * scale + causal_mask) @ V
+// ============================================================
+__global__ void attention_kernel(
+    const float* Q, const float* K, const float* V, float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    float scale, int startPos
+) {
+    // Each block handles one (seq, head) pair.
+    int seq = blockIdx.x;
+    int head = blockIdx.y;
+
+    int kvHead = head / (numHeads / numKVHeads); // GQA: several query heads share one KV head
+
+    const float* q = Q + seq * numHeads * headDim + head * headDim;
+    float* o = out + seq * numHeads * headDim + head * headDim;
+
+    // Dynamic shared memory holds the attention scores for this row.
+    extern __shared__ float shared[];
+    float* scores = shared; // [kvLen]
+
+    // Compute Q @ K.T for this head.
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        const float* k = K + kv * numKVHeads * headDim + kvHead * headDim;
+
+        float dot = 0.0f;
+        for (int d = 0; d < headDim; d++) {
+            dot += q[d] * k[d];
+        }
+
+        // Apply the causal mask: keys past the query position are excluded.
+        int queryPos = startPos + seq;
+        int keyPos = kv;
+        if (keyPos > queryPos) {
+            dot = -INFINITY;
+        }
+
+        scores[kv] = dot * scale;
+    }
+    __syncthreads();
+
+    // Softmax over the scores.
+    // Per-thread max.
+    float maxVal = -INFINITY;
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        maxVal = fmaxf(maxVal, scores[kv]);
+    }
+
+    // Reduce the max across threads (blockDim.x is 256 in the host wrappers below).
+    __shared__ float smax[256];
+    smax[threadIdx.x] = maxVal;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
+        }
+        __syncthreads();
+    }
+    maxVal = smax[0];
+
+    // Exponentiate and accumulate the sum.
+    float sum = 0.0f;
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        float val = expf(scores[kv] - maxVal);
+        scores[kv] = val;
+        sum += val;
+    }
+
+    __shared__ float ssum[256];
+    ssum[threadIdx.x] = sum;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            ssum[threadIdx.x] += ssum[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+    sum = ssum[0];
+
+    // Normalize.
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        scores[kv] /= sum;
+    }
+    __syncthreads();
+
+    // Weighted sum of V.
+    for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
+        float val = 0.0f;
+        for (int kv = 0; kv < kvLen; kv++) {
+            const float* v = V + kv * numKVHeads * headDim + kvHead * headDim;
+            val += scores[kv] * v[d];
+        }
+        o[d] = val;
+    }
+}
+
+__global__ void paged_attention_kernel(
+    const float* Q,
+    const float* const* KBlocks,
+    const float* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos
+) {
+    int seq = blockIdx.x;
+    int head = blockIdx.y;
+
+    int kvHead = head / (numHeads / numKVHeads);
+
+    const float* q = Q + seq * numHeads * headDim + head * headDim;
+    float* o = out + seq * numHeads * headDim + head * headDim;
+
+    const int kvStride = numKVHeads * headDim;
+
+    int queryPos = startPos + seq;
+
+    // Pass 1: running max of the masked scores.
+    float localMax = -INFINITY;
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        int keyPos = kv;
+        if (keyPos > queryPos) {
+            continue;
+        }
+        const int blockIdxKV = kv / blockSize;
+        const int blockOff = kv % blockSize;
+        const float* kBlock = KBlocks[blockIdxKV];
+        const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        for (int d = 0; d < headDim; d++) {
+            dot += q[d] * k[d];
+        }
+        float score = dot * scale;
+        localMax = fmaxf(localMax, score);
+    }
+
+    __shared__ float smax[256];
+    smax[threadIdx.x] = localMax;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            smax[threadIdx.x] = fmaxf(smax[threadIdx.x], smax[threadIdx.x + s]);
+        }
+        __syncthreads();
+    }
+    float maxVal = smax[0];
+
+    // Pass 2: sum of exp(score - max).
+    float localSum = 0.0f;
+    for (int kv = threadIdx.x; kv < kvLen; kv += blockDim.x) {
+        int keyPos = kv;
+        if (keyPos > queryPos) {
+            continue;
+        }
+        const int blockIdxKV = kv / blockSize;
+        const int blockOff = kv % blockSize;
+        const float* kBlock = KBlocks[blockIdxKV];
+        const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
+
+        float dot = 0.0f;
+        for (int d = 0; d < headDim; d++) {
+            dot += q[d] * k[d];
+        }
+        float score = dot * scale;
+        localSum += expf(score - maxVal);
+    }
+
+    __shared__ float ssum[256];
+    ssum[threadIdx.x] = localSum;
+    __syncthreads();
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            ssum[threadIdx.x] += ssum[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
+    float sumVal = ssum[0];
+    float invSum = (sumVal > 0.0f) ? (1.0f / sumVal) : 0.0f;
+
+    // Pass 3: accumulate the output.
+    for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
+        float outVal = 0.0f;
+        for (int kv = 0; kv < kvLen; kv++) {
+            int keyPos = kv;
+            if (keyPos > queryPos) {
+                break;
+            }
+            const int blockIdxKV = kv / blockSize;
+            const int blockOff = kv % blockSize;
+            const float* kBlock = KBlocks[blockIdxKV];
+            const float* k = kBlock + blockOff * kvStride + kvHead * headDim;
+
+            float dot = 0.0f;
+            for (int kd = 0; kd < headDim; kd++) {
+                dot += q[kd] * k[kd];
+            }
+            float score = dot * scale;
+            float w = expf(score - maxVal) * invSum;
+
+            const float* vBlock = VBlocks[blockIdxKV];
+            const float* v = vBlock + blockOff * kvStride + kvHead * headDim;
+            outVal += w * v[d];
+        }
+        o[d] = outVal;
+    }
+}
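+// Note: the kernel above trades compute for memory. It recomputes the Q.K dot
+// products in each of its three passes instead of caching the score row, so it
+// needs no kvLen-sized shared buffer (unlike attention_kernel). The batch
+// variant below avoids the recomputation entirely with a streaming softmax.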
+
+__global__ void paged_attention_batch_kernel(
+    const float* Q,
+    const float* const* KBlocksFlat,
+    const float* const* VBlocksFlat,
+    const int* blockOffsets,
+    const int* kvLens,
+    const int* queryPos,
+    float* out,
+    int numTokens,
+    int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale
+) {
+    int tok = blockIdx.x;
+    int head = blockIdx.y;
+    if (tok >= numTokens) {
+        return;
+    }
+
+    int kvHead = head / (numHeads / numKVHeads);
+
+    const float* q = Q + tok * numHeads * headDim + head * headDim;
+    float* o = out + tok * numHeads * headDim + head * headDim;
+
+    int kvLen = kvLens[tok];
+    int qPos = queryPos[tok];
+    int base = blockOffsets[tok];
+
+    const int kvStride = numKVHeads * headDim;
+    const int effectiveLen = (kvLen < (qPos + 1)) ? kvLen : (qPos + 1);
+
+    // Each of the first headDim threads owns one output element
+    // (assumes headDim <= blockDim.x).
+    float acc = 0.0f;
+
+    // Streaming (online) softmax state shared by the block.
+    __shared__ float m;
+    __shared__ float l;
+    __shared__ float alpha;
+    __shared__ float beta;
+    __shared__ float dotShared;
+
+    if (threadIdx.x == 0) {
+        m = -INFINITY;
+        l = 0.0f;
+    }
+    __syncthreads();
+
+    for (int kv = 0; kv < effectiveLen; kv++) {
+        const int bidx = kv / blockSize;
+        const int boff = kv % blockSize;
+
+        const float* kBlock = KBlocksFlat[base + bidx];
+        const float* k = kBlock + boff * kvStride + kvHead * headDim;
+
+        float partial = 0.0f;
+        for (int d = threadIdx.x; d < headDim; d += blockDim.x) {
+            partial = fmaf(q[d], k[d], partial);
+        }
+
+        // Block-wide sum of the partial dot products: warp shuffle reduction,
+        // then a second reduction over the warp leaders.
+        for (int offset = 16; offset > 0; offset >>= 1) {
+            partial += __shfl_down_sync(0xffffffff, partial, offset);
+        }
+        // One slot per warp; assumes blockDim.x == 256 (8 full warps) so every
+        // slot is written before warp 0 reads them all.
+        __shared__ float warpSum[8];
+        int lane = threadIdx.x & 31;
+        int warp = threadIdx.x >> 5;
+        if (lane == 0) {
+            warpSum[warp] = partial;
+        }
+        __syncthreads();
+        if (warp == 0) {
+            float v = (lane < 8) ? warpSum[lane] : 0.0f;
+            for (int offset = 16; offset > 0; offset >>= 1) {
+                v += __shfl_down_sync(0xffffffff, v, offset);
+            }
+            if (lane == 0) {
+                dotShared = v;
+            }
+        }
+        __syncthreads();
+
+        float score = dotShared * scale;
+
+        // Update the running max and normalizer; alpha rescales the old
+        // accumulator, beta weights the new value row.
+        if (threadIdx.x == 0) {
+            float newM = fmaxf(m, score);
+            float a = expf(m - newM);
+            float b = expf(score - newM);
+            m = newM;
+            l = l * a + b;
+            alpha = a;
+            beta = b;
+        }
+        __syncthreads();
+
+        if (threadIdx.x < headDim) {
+            const float* vBlock = VBlocksFlat[base + bidx];
+            const float* v = vBlock + boff * kvStride + kvHead * headDim;
+            acc = fmaf(beta, v[threadIdx.x], acc * alpha);
+        }
+        __syncthreads();
+    }
+
+    if (threadIdx.x < headDim) {
+        float invL = (l > 0.0f) ? (1.0f / l) : 0.0f;
+        o[threadIdx.x] = acc * invL;
+    }
+}
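+// For each (token, head) the loop above maintains the streaming-softmax state
+//   m_t   = max(m_{t-1}, s_t)
+//   l_t   = l_{t-1} * exp(m_{t-1} - m_t) + exp(s_t - m_t)
+//   acc_t = acc_{t-1} * exp(m_{t-1} - m_t) + exp(s_t - m_t) * v_t
+// so acc_T / l_T equals softmax(s) @ V without ever materializing the full
+// score row; alpha and beta above are the two exponential factors.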
+
+int cuda_attention_f32(
+    const float* Q, const float* K, const float* V, float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    float scale, int startPos
+) {
+    dim3 blocks(seqLen, numHeads);
+    int threads = 256;
+    size_t sharedMem = kvLen * sizeof(float);
+
+    attention_kernel<<<blocks, threads, sharedMem>>>(
+        Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
+    );
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
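+// attention_kernel requests kvLen * sizeof(float) bytes of dynamic shared
+// memory, so the launch above fails once that exceeds the per-block limit
+// (typically 48 KB unless the kernel opts in via cudaFuncSetAttribute with
+// cudaFuncAttributeMaxDynamicSharedMemorySize). Illustrative pre-flight check;
+// the helper name is a placeholder, not part of this file's API:
+/*
+static bool attention_fits_in_shared(int kvLen) {
+    int dev = 0;
+    cudaGetDevice(&dev);
+    int maxShared = 0; // bytes of shared memory available per block
+    cudaDeviceGetAttribute(&maxShared, cudaDevAttrMaxSharedMemoryPerBlock, dev);
+    return (size_t)kvLen * sizeof(float) <= (size_t)maxShared;
+}
+*/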
+
+int cuda_paged_attention_f32(
+    const float* Q,
+    const float* const* KBlocks,
+    const float* const* VBlocks,
+    float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    int blockSize,
+    float scale, int startPos
+) {
+    // Split-KV flash decoding for long contexts: each (query, head) pair is
+    // handled by several blocks, each covering kPagedAttentionSplitSize keys,
+    // and the partial results are merged in a second kernel.
+    const int maxEffectiveLen = (kvLen < (startPos + seqLen)) ? kvLen : (startPos + seqLen);
+    const int numSplits = (maxEffectiveLen + kPagedAttentionSplitSize - 1) / kPagedAttentionSplitSize;
+    const int qhCount = seqLen * numHeads;
+    const bool useSplit = (headDim <= 256) && (numSplits > 1) && (qhCount < kPagedAttentionSplitQHThreshold);
+
+    if (!useSplit) {
+        dim3 blocks(seqLen, numHeads);
+        int threads = 256;
+        paged_attention_kernel<<<blocks, threads>>>(
+            Q, KBlocks, VBlocks, out,
+            seqLen, kvLen, numHeads, numKVHeads, headDim,
+            blockSize, scale, startPos
+        );
+        CHECK_CUDA(cudaGetLastError());
+        return 0;
+    }
+
+    // Workspace layout: per (query, head, split) one running max, one exp-sum,
+    // and a headDim-wide partial output.
+    const size_t splitCount = (size_t)qhCount * (size_t)numSplits;
+    const size_t totalFloats = splitCount * (size_t)(headDim + 2);
+    float* buf = reinterpret_cast<float*>(cuda_malloc(totalFloats * sizeof(float)));
+    if (buf == nullptr) {
+        return 1;
+    }
+    float* partialMax = buf;
+    float* partialSum = partialMax + splitCount;
+    float* partialOut = partialSum + splitCount;
+
+    dim3 blocks1(seqLen, numHeads, numSplits);
+    paged_attention_split_kv_kernel<float><<<blocks1, 32>>>(
+        Q, KBlocks, VBlocks,
+        partialMax, partialSum, partialOut,
+        seqLen, kvLen, numHeads, numKVHeads, headDim,
+        blockSize,
+        scale, startPos,
+        numSplits, kPagedAttentionSplitSize
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    dim3 blocks2(seqLen, numHeads);
+    paged_attention_split_kv_reduce_kernel<<<blocks2, 32>>>(
+        partialMax, partialSum, partialOut,
+        out,
+        seqLen, numHeads, headDim, numSplits
+    );
+    CHECK_CUDA(cudaGetLastError());
+
+    cuda_free(buf);
+    return 0;
+}
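+// For reference, per-split softmax partials with running max m_i and exp-sum
+// l_i combine exactly via log-sum-exp (the usual flash-decoding merge):
+//   M = max_i m_i,   L = sum_i l_i * exp(m_i - M)
+// and a split output normalized by its own l_i is reweighted by
+//   w_i = l_i * exp(m_i - M) / L.
+// The exact storage convention used here is whatever
+// paged_attention_split_kv_kernel / paged_attention_split_kv_reduce_kernel
+// (defined earlier in this file) agree on.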
+
+int cuda_attention_f32_timed(
+    const float* Q, const float* K, const float* V, float* out,
+    int seqLen, int kvLen, int numHeads, int numKVHeads, int headDim,
+    float scale, int startPos, float* ms
+) {
+    cudaEvent_t evStart;
+    cudaEvent_t evStop;
+    CHECK_CUDA(cudaEventCreate(&evStart));
+    CHECK_CUDA(cudaEventCreate(&evStop));
+
+    dim3 blocks(seqLen, numHeads);
+    int threads = 256;
+    size_t sharedMem = kvLen * sizeof(float);
+
+    CHECK_CUDA(cudaEventRecord(evStart));
+    attention_kernel<<<blocks, threads, sharedMem>>>(
+        Q, K, V, out, seqLen, kvLen, numHeads, numKVHeads, headDim, scale, startPos
+    );
+    CHECK_CUDA(cudaEventRecord(evStop));
+    CHECK_CUDA(cudaEventSynchronize(evStop));
+
+    float elapsed = 0.0f;
+    CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
+    if (ms != NULL) {
+        *ms = elapsed;
+    }
+
+    CHECK_CUDA(cudaEventDestroy(evStart));
+    CHECK_CUDA(cudaEventDestroy(evStop));
+    CHECK_CUDA(cudaGetLastError());
+    return 0;
+}
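+// Illustrative benchmarking loop (the device buffer names are placeholders,
+// not declared in this file):
+/*
+float totalMs = 0.0f;
+for (int iter = 0; iter < 10; iter++) {
+    float ms = 0.0f;
+    cuda_attention_f32_timed(dQ, dK, dV, dOut,
+                             seqLen, kvLen, numHeads, numKVHeads, headDim,
+                             1.0f / sqrtf((float)headDim), 0, &ms);
+    totalMs += ms;
+}
+printf("attention: %.3f ms/iter\n", totalMs / 10.0f);
+*/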
+
+// Debug helper
+int cuda_print_struct_sizes() {
+    printf("GPU Struct Sizes:\n");
+    printf("BlockQ2_K: %zu\n", sizeof(BlockQ2_K));
+    printf("BlockQ3_K: %zu\n", sizeof(BlockQ3_K));
+    printf("BlockQ4_K: %zu\n", sizeof(BlockQ4_K));
+    printf("BlockQ6_K: %zu\n", sizeof(BlockQ6_K));
+    printf("BlockQ8_K: %zu\n", sizeof(BlockQ8_K));
+    return 0;
+}