#include "cuda_common.cuh" #include namespace { // Simple tiled GEMM kernels for correctness-first dense matmul. // These are used as the default dense GEMM path when CUTLASS is not built. constexpr int TILE = 16; __global__ void matmul_f32_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C, int M, int K, int N) { __shared__ float As[TILE][TILE]; __shared__ float Bs[TILE][TILE]; const int row = blockIdx.y * TILE + threadIdx.y; const int col = blockIdx.x * TILE + threadIdx.x; float acc = 0.0f; for (int t = 0; t < (K + TILE - 1) / TILE; ++t) { const int aCol = t * TILE + threadIdx.x; const int bRow = t * TILE + threadIdx.y; As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f; Bs[threadIdx.y][threadIdx.x] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0f; __syncthreads(); #pragma unroll for (int i = 0; i < TILE; ++i) { acc += As[threadIdx.y][i] * Bs[i][threadIdx.x]; } __syncthreads(); } if (row < M && col < N) { C[row * N + col] = acc; } } // Computes C = A @ B^T where B is stored row-major [N, K]. __global__ void matmul_f32_nt_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C, int M, int K, int N) { __shared__ float As[TILE][TILE]; __shared__ float Bs[TILE][TILE]; const int row = blockIdx.y * TILE + threadIdx.y; const int col = blockIdx.x * TILE + threadIdx.x; // maps to n float acc = 0.0f; for (int t = 0; t < (K + TILE - 1) / TILE; ++t) { const int aCol = t * TILE + threadIdx.x; const int bCol = t * TILE + threadIdx.y; // k index for B row As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f; Bs[threadIdx.y][threadIdx.x] = (col < N && bCol < K) ? B[col * K + bCol] : 0.0f; __syncthreads(); #pragma unroll for (int i = 0; i < TILE; ++i) { acc += As[threadIdx.y][i] * Bs[i][threadIdx.x]; } __syncthreads(); } if (row < M && col < N) { C[row * N + col] = acc; } } // Computes C = A @ B^T where A and B are stored as IEEE half in uint16. __global__ void matmul_f16_nt_kernel(const __half* __restrict__ A, const __half* __restrict__ B, float* __restrict__ C, int M, int K, int N) { __shared__ __half As[TILE][TILE]; __shared__ __half Bs[TILE][TILE]; const int row = blockIdx.y * TILE + threadIdx.y; const int col = blockIdx.x * TILE + threadIdx.x; float acc = 0.0f; for (int t = 0; t < (K + TILE - 1) / TILE; ++t) { const int aCol = t * TILE + threadIdx.x; const int bCol = t * TILE + threadIdx.y; As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : __float2half(0.0f); Bs[threadIdx.y][threadIdx.x] = (col < N && bCol < K) ? B[col * K + bCol] : __float2half(0.0f); __syncthreads(); #pragma unroll for (int i = 0; i < TILE; ++i) { acc += __half2float(As[threadIdx.y][i]) * __half2float(Bs[i][threadIdx.x]); } __syncthreads(); } if (row < M && col < N) { C[row * N + col] = acc; } } } // namespace __global__ void matmul_q5k_kernel(float* A, const BlockQ5_K* B, float* C, int M, int K, int N, int blocksPerRow) { const int row = blockIdx.y; const int warp = threadIdx.y; const int lane = threadIdx.x; const int col = blockIdx.x * 8 + warp; // row is uniform across the block, so an early return here is safe. if (row >= M) return; // col is warp-specific. Do NOT early-return on col>=N because we use __syncthreads(). const bool colIn = (col < N); float sum = 0.0f; // Cache the A tile (256 floats) once per block so the 8 warps (8 columns) reuse it. 
__global__ void matmul_q5k_kernel(float* A, const BlockQ5_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    // row is uniform across the block, so an early return here is safe.
    if (row >= M) return;
    // col is warp-specific. Do NOT early-return on col>=N because we use __syncthreads().
    const bool colIn = (col < N);
    float sum = 0.0f;
    // Cache the A tile (256 floats) once per block so the 8 warps (8 columns) reuse it.
    __shared__ float a_sh[256];
    __shared__ unsigned char sc_sh[8][8];
    __shared__ unsigned char m_sh[8][8];
    __shared__ float ds_sh[8][8];
    __shared__ float dm_sh[8][8];
    __shared__ float d_sh[8];
    __shared__ float dmin_sh[8];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        // Cache A tile once per block (256 floats). Each thread loads one element.
        const int tid = warp * 32 + lane;
        const float* aRow = A + row * K + blk * 256;
        a_sh[tid] = aRow[tid];
        if (colIn) {
            const BlockQ5_K* b = &B[col * blocksPerRow + blk];
            if (lane < 8) {
                unsigned char sc;
                unsigned char mn;
                if (lane < 4) {
                    sc = b->scales[lane] & 63;
                    mn = b->scales[lane + 4] & 63;
                } else {
                    const int j = lane;
                    sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
                    mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
                }
                sc_sh[warp][lane] = sc;
                m_sh[warp][lane] = mn;
            }
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
                dmin_sh[warp] = fp16_to_fp32(b->dmin);
            }
        }
        // Ensure all warps have finished loading this block's a_sh before use,
        // and that no warp overwrites a_sh while others are still reading it.
        __syncthreads();
        if (colIn) {
            const BlockQ5_K* b = &B[col * blocksPerRow + blk];
            // Precompute per-group multipliers once (one lane per group).
            if (lane < 8) {
                const float d = d_sh[warp];
                const float dmin = dmin_sh[warp];
                const unsigned char sc = sc_sh[warp][lane];
                const unsigned char mn = m_sh[warp][lane];
                ds_sh[warp][lane] = d * (float)sc;
                dm_sh[warp][lane] = dmin * (float)mn;
            }
            __syncwarp();
            const unsigned char hb = b->qh[lane];
            #pragma unroll
            for (int p = 0; p < 4; p++) {
                const unsigned char qs = b->qs[p * 32 + lane];
                int q0 = qs & 0xF;
                int q1 = qs >> 4;
                q0 += ((hb >> (2 * p)) & 1) << 4;
                q1 += ((hb >> (2 * p + 1)) & 1) << 4;
                const int idx0 = p * 64 + lane;
                const int idx1 = idx0 + 32;
                const int g0 = 2 * p;
                const int g1 = g0 + 1;
                const float ds0 = ds_sh[warp][g0];
                const float dm0 = dm_sh[warp][g0];
                const float ds1 = ds_sh[warp][g1];
                const float dm1 = dm_sh[warp][g1];
                sum += a_sh[idx0] * ((float)q0 * ds0 - dm0);
                sum += a_sh[idx1] * ((float)q1 * ds1 - dm1);
            }
        }
        __syncthreads();
    }
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f32_q5k(float* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q5k_kernel<<<blocks, threads>>>(A, (const BlockQ5_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

int cuda_matmul_f32(float* A, float* B, float* C, int M, int K, int N) {
    if (M <= 0 || N <= 0 || K <= 0) return 0;
    dim3 threads(TILE, TILE);
    dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
    matmul_f32_kernel<<<blocks, threads>>>(A, B, C, M, K, N);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

int cuda_matmul_f32_nt(float* A, float* B, float* C, int M, int K, int N) {
    if (M <= 0 || N <= 0 || K <= 0) return 0;
    dim3 threads(TILE, TILE);
    dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
    matmul_f32_nt_kernel<<<blocks, threads>>>(A, B, C, M, K, N);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

int cuda_matmul_f16_nt(const unsigned short* A, const unsigned short* B, float* C, int M, int K, int N) {
    if (M <= 0 || N <= 0 || K <= 0) return 0;
    dim3 threads(TILE, TILE);
    dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
    matmul_f16_nt_kernel<<<blocks, threads>>>(reinterpret_cast<const __half*>(A),
                                              reinterpret_cast<const __half*>(B), C, M, K, N);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
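// Illustrative host-side usage of the dense FP32 path above. This is a sketch
// only; the names hostA/hostB/hostC and dA/dB/dC are assumptions for the
// example, not part of this file. A is row-major [M,K], B is [K,N], C is [M,N]:
//
//   float *dA, *dB, *dC;
//   cudaMalloc(&dA, sizeof(float) * M * K);
//   cudaMalloc(&dB, sizeof(float) * K * N);
//   cudaMalloc(&dC, sizeof(float) * M * N);
//   cudaMemcpy(dA, hostA, sizeof(float) * M * K, cudaMemcpyHostToDevice);
//   cudaMemcpy(dB, hostB, sizeof(float) * K * N, cudaMemcpyHostToDevice);
//   cuda_matmul_f32(dA, dB, dC, M, K, N);  // launches on the default stream
//   cudaMemcpy(hostC, dC, sizeof(float) * M * N, cudaMemcpyDeviceToHost);  // synchronizes
//   cudaFree(dA); cudaFree(dB); cudaFree(dC);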
// ============================================================
// Fused Q8_K MatMul Kernel (tiled)
// C[m,n] = sum_k A[m,k] * dequant(B[n,k])
// Uses shared memory tiles to reduce global memory pressure.
// ============================================================
__global__ void matmul_q8k_kernel(float* A, const BlockQ8_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ float d_sh[8];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        const int tid = warp * 32 + lane;
        const float* aRow = A + row * K + blk * 256;
        a_sh[tid] = aRow[tid];
        if (colIn && lane == 0) {
            d_sh[warp] = B[col * blocksPerRow + blk].d;
        }
        __syncthreads();
        if (colIn) {
            const BlockQ8_K* b = &B[col * blocksPerRow + blk];
            const float d = d_sh[warp];
            // Each lane handles 8 weights in the 256-wide block.
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                const int idx = lane + (i * 32);  // 0..255
                const float w = d * (float)((int)b->qs[idx]);
                sum += a_sh[idx] * w;
            }
        }
        __syncthreads();
    }
    // Warp reduction
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f32_q8k(float* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q8k_kernel<<<blocks, threads>>>(A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

int cuda_matmul_f32_q8k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms) {
    cudaEvent_t evStart;
    cudaEvent_t evStop;
    CHECK_CUDA(cudaEventCreate(&evStart));
    CHECK_CUDA(cudaEventCreate(&evStop));
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    CHECK_CUDA(cudaEventRecord(evStart));
    matmul_q8k_kernel<<<blocks, threads>>>(A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaEventRecord(evStop));
    CHECK_CUDA(cudaEventSynchronize(evStop));
    float elapsed = 0.0f;
    CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
    if (ms != NULL) {
        *ms = elapsed;
    }
    CHECK_CUDA(cudaEventDestroy(evStart));
    CHECK_CUDA(cudaEventDestroy(evStop));
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
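// The timed wrapper above reports elapsed kernel time in milliseconds, which
// makes throughput easy to derive on the host. A sketch of the arithmetic
// (variable names are illustrative):
//
//   float ms = 0.0f;
//   cuda_matmul_f32_q8k_timed(dA, dB, dC, M, K, N, &ms);
//   // A GEMM does 2*M*N*K floating-point ops (one multiply + one add per term),
//   // so e.g. M=1, N=4096, K=4096 in 0.05 ms is
//   //   2 * 1 * 4096 * 4096 / 0.05e-3 s ≈ 0.67 TFLOP/s.
//   double gflops = 2.0 * M * N * K / (ms * 1.0e6);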
// ============================================================
// FP16 Input Variants - 2x memory bandwidth for activations
// Input A is FP16, dequantized weights computed in FP32,
// accumulation in FP32, output FP32.
// ============================================================
__global__ void matmul_q8k_kernel_f16in(const __half* A, const BlockQ8_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ float d_sh[8];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        const int tid = warp * 32 + lane;
        const __half* aRow = A + row * K + blk * 256;
        // Load FP16, convert to FP32 in shared memory
        a_sh[tid] = __half2float(aRow[tid]);
        if (colIn && lane == 0) {
            d_sh[warp] = B[col * blocksPerRow + blk].d;
        }
        __syncthreads();
        if (colIn) {
            const BlockQ8_K* b = &B[col * blocksPerRow + blk];
            const float d = d_sh[warp];
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                const int idx = lane + (i * 32);
                const float w = d * (float)((int)b->qs[idx]);
                sum += a_sh[idx] * w;
            }
        }
        __syncthreads();
    }
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f16_q8k(const void* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q8k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// ============================================================
// Fused Q4_K MatMul Kernel - simplified version
// For full performance, would need shared memory tiling
// ============================================================
__global__ void matmul_q4k_kernel(float* A, const BlockQ4_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ unsigned char sc_sh[8][8];
    __shared__ unsigned char m_sh[8][8];
    __shared__ float ds_sh[8][8];
    __shared__ float dm_sh[8][8];
    __shared__ float d_sh[8];
    __shared__ float dmin_sh[8];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        const int tid = warp * 32 + lane;
        const float* aRow = A + row * K + blk * 256;
        a_sh[tid] = aRow[tid];
        if (colIn) {
            const BlockQ4_K* b = &B[col * blocksPerRow + blk];
            // Parallel unpack scale/min for groups 0..7 (one lane per group).
            if (lane < 8) {
                unsigned char sc;
                unsigned char mn;
                if (lane < 4) {
                    sc = b->scales[lane] & 63;
                    mn = b->scales[lane + 4] & 63;
                } else {
                    const int j = lane;
                    sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
                    mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
                }
                sc_sh[warp][lane] = sc;
                m_sh[warp][lane] = mn;
            }
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
                dmin_sh[warp] = fp16_to_fp32(b->dmin);
            }
        }
        __syncthreads();
        if (colIn) {
            const BlockQ4_K* b = &B[col * blocksPerRow + blk];
            // Precompute per-group float multipliers once.
            if (lane < 8) {
                const float d = d_sh[warp];
                const float dmin = dmin_sh[warp];
                const unsigned char sc = sc_sh[warp][lane];
                const unsigned char mn = m_sh[warp][lane];
                ds_sh[warp][lane] = d * (float)sc;
                dm_sh[warp][lane] = dmin * (float)mn;
            }
            __syncwarp();
            const float ds0 = ds_sh[warp][0]; const float dm0 = dm_sh[warp][0];
            const float ds1 = ds_sh[warp][1]; const float dm1 = dm_sh[warp][1];
            const float ds2 = ds_sh[warp][2]; const float dm2 = dm_sh[warp][2];
            const float ds3 = ds_sh[warp][3]; const float dm3 = dm_sh[warp][3];
            const float ds4 = ds_sh[warp][4]; const float dm4 = dm_sh[warp][4];
            const float ds5 = ds_sh[warp][5]; const float dm5 = dm_sh[warp][5];
            const float ds6 = ds_sh[warp][6]; const float dm6 = dm_sh[warp][6];
            const float ds7 = ds_sh[warp][7]; const float dm7 = dm_sh[warp][7];
            // Each lane processes 4 bytes; each byte contains 2 nibbles => 8 values per lane.
            // This halves qs loads and reduces bit ops.
            #pragma unroll
            for (int p = 0; p < 4; p++) {
                const unsigned char qs = b->qs[p * 32 + lane];
                const int q0 = qs & 0xF;
                const int q1 = qs >> 4;
                const int idx0 = p * 64 + lane;  // group = 2*p
                const int idx1 = idx0 + 32;      // group = 2*p + 1
                float dsA, dmA, dsB, dmB;
                if (p == 0) {
                    dsA = ds0; dmA = dm0; dsB = ds1; dmB = dm1;
                } else if (p == 1) {
                    dsA = ds2; dmA = dm2; dsB = ds3; dmB = dm3;
                } else if (p == 2) {
                    dsA = ds4; dmA = dm4; dsB = ds5; dmB = dm5;
                } else {
                    dsA = ds6; dmA = dm6; dsB = ds7; dmB = dm7;
                }
                const float w0 = (float)q0 * dsA - dmA;
                const float w1 = (float)q1 * dsB - dmB;
                sum += a_sh[idx0] * w0;
                sum += a_sh[idx1] * w1;
            }
        }
        __syncthreads();
    }
    // Warp reduction
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f32_q4k(float* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q4k_kernel<<<blocks, threads>>>(A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

int cuda_matmul_f32_q4k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms) {
    cudaEvent_t evStart;
    cudaEvent_t evStop;
    CHECK_CUDA(cudaEventCreate(&evStart));
    CHECK_CUDA(cudaEventCreate(&evStop));
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    CHECK_CUDA(cudaEventRecord(evStart));
    matmul_q4k_kernel<<<blocks, threads>>>(A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaEventRecord(evStop));
    CHECK_CUDA(cudaEventSynchronize(evStop));
    float elapsed = 0.0f;
    CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
    if (ms != NULL) {
        *ms = elapsed;
    }
    CHECK_CUDA(cudaEventDestroy(evStart));
    CHECK_CUDA(cudaEventDestroy(evStop));
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
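// Worked example of the Q4_K weight reconstruction used above (values chosen
// for illustration only): for a nibble q = 9 in a group with block scale
// d = 0.02, block min dmin = 0.01, 6-bit group scale sc = 40 and group min
// mn = 12, the precomputed multipliers are ds = d*sc = 0.8 and dm = dmin*mn = 0.12,
// so the dequantized weight is w = q*ds - dm = 9*0.8 - 0.12 = 7.08.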
// ============================================================
// Fused Q2_K MatMul Kernel - Naive
// ============================================================
__global__ void matmul_q2k_kernel(float* A, const BlockQ2_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ float d_sh[8];
    __shared__ float dmin_sh[8];
    __shared__ unsigned char scales_sh[8][16];
    __shared__ unsigned char qs_sh[8][64];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        // Cache A tile once per block (256 floats) to avoid redundant global loads.
        // Each thread loads one element: tid in [0,255].
        const int tid = warp * 32 + lane;
        const float* aRow = A + row * K + blk * 256;
        a_sh[tid] = aRow[tid];
        if (colIn) {
            const BlockQ2_K* b = &B[col * blocksPerRow + blk];
            // Cooperative per-warp cache.
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
                dmin_sh[warp] = fp16_to_fp32(b->dmin);
            }
            if (lane < 16) {
                scales_sh[warp][lane] = b->scales[lane];
            }
            // Load 64 bytes qs with 32 lanes.
            qs_sh[warp][lane] = b->qs[lane];
            qs_sh[warp][lane + 32] = b->qs[lane + 32];
        }
        __syncthreads();
        if (colIn) {
            const float d = d_sh[warp];
            const float dmin = dmin_sh[warp];
            // Each lane handles 8 values.
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                const int idx = lane + (i * 32);  // 0..255
                const int is = idx >> 5;          // 0..7
                const int iq = idx & 31;          // 0..31
                const int qsIdx = (is >> 2) * 32 + iq;
                const int shift = (is & 3) * 2;
                const int val = (qs_sh[warp][qsIdx] >> shift) & 3;
                const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
                const unsigned char sc = scales_sh[warp][scIdx];
                const float dl = d * (float)(sc & 0xF);
                const float ml = dmin * (float)(sc >> 4);
                const float w = dl * (float)val - ml;
                sum += a_sh[idx] * w;
            }
        }
        // Ensure all warps finished reading this block's a_sh before it is overwritten.
        __syncthreads();
    }
    // Warp reduction
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f32_q2k(float* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q2k_kernel<<<blocks, threads>>>(A, (const BlockQ2_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
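// Worked example of the Q2_K per-group scale byte decoded above (illustrative
// numbers): a scales byte of 0xA3 packs the 4-bit scale in the low nibble and
// the 4-bit min in the high nibble, so sc & 0xF = 3 and sc >> 4 = 10. With
// d = 0.5 and dmin = 0.25 this gives dl = 1.5 and ml = 2.5, and a 2-bit quant
// value of 2 dequantizes to w = dl*2 - ml = 0.5.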
// ============================================================
// Fused Q3_K MatMul Kernel
// ============================================================
__global__ void matmul_q3k_kernel(float* A, const BlockQ3_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ float d_sh[8];
    __shared__ unsigned char scales_sh[8][12];
    __shared__ unsigned char qs_sh[8][64];
    __shared__ unsigned char hmask_sh[8][32];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        // Cache A tile once per block (256 floats). Each thread loads one element.
        const int tid = warp * 32 + lane;
        const float* aRow = A + row * K + blk * 256;
        a_sh[tid] = aRow[tid];
        if (colIn) {
            const BlockQ3_K* b = &B[col * blocksPerRow + blk];
            // Cache quant block bytes.
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
            }
            if (lane < 12) {
                scales_sh[warp][lane] = b->scales[lane];
            }
            // qs: 64 bytes
            qs_sh[warp][lane] = b->qs[lane];
            qs_sh[warp][lane + 32] = b->qs[lane + 32];
            // hmask: 32 bytes
            hmask_sh[warp][lane] = b->hmask[lane];
        }
        __syncthreads();
        if (colIn) {
            const float d = d_sh[warp];
            // Each lane handles 8 elements.
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                const int idx = lane + (i * 32);  // 0..255
                const int is = idx >> 5;          // 0..7
                const int iq = idx & 31;          // 0..31
                const int qsIdx = (is >> 2) * 32 + iq;
                const int shift = (is & 3) * 2;
                int qv = (qs_sh[warp][qsIdx] >> shift) & 0x3;
                const unsigned char m = (unsigned char)(1 << ((is >> 2) * 4 + (is & 3)));
                if ((hmask_sh[warp][iq] & m) == 0) {
                    qv -= 4;
                }
                const int sIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);  // 0..15
                unsigned char sc;
                if (sIdx < 8) {
                    sc = scales_sh[warp][sIdx] & 0xF;
                } else {
                    sc = scales_sh[warp][sIdx - 8] >> 4;
                }
                sc |= ((scales_sh[warp][8 + (sIdx & 3)] >> (2 * (sIdx >> 2))) & 0x3) << 4;
                const float scale = (float)((int)((signed char)sc) - 32);
                const float w = d * scale * (float)qv;
                sum += a_sh[idx] * w;
            }
        }
        __syncthreads();
    }
    // Warp reduction
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f32_q3k(float* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q3k_kernel<<<blocks, threads>>>(A, (const BlockQ3_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// ============================================================
// Fused Q6_K MatMul Kernel
// ============================================================
__global__ void matmul_q6k_kernel(float* A, const BlockQ6_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ float d_sh[8];
    __shared__ signed char scales_sh[8][16];
    __shared__ unsigned char ql_sh[8][128];
    __shared__ unsigned char qh_sh[8][64];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        // Cache A tile once per block.
        const int tid = warp * 32 + lane;
        const float* aRow = A + row * K + blk * 256;
        a_sh[tid] = aRow[tid];
        if (colIn) {
            const BlockQ6_K* b = &B[col * blocksPerRow + blk];
            // Cache quant block bytes.
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
            }
            if (lane < 16) {
                scales_sh[warp][lane] = b->scales[lane];
            }
            // qh: 64 bytes
            qh_sh[warp][lane] = b->qh[lane];
            qh_sh[warp][lane + 32] = b->qh[lane + 32];
            // ql: 128 bytes
            ql_sh[warp][lane] = b->ql[lane];
            ql_sh[warp][lane + 32] = b->ql[lane + 32];
            ql_sh[warp][lane + 64] = b->ql[lane + 64];
            ql_sh[warp][lane + 96] = b->ql[lane + 96];
        }
        __syncthreads();
        if (colIn) {
            const float d = d_sh[warp];
            // Each lane handles 8 elements.
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                const int idx = lane + (i * 32);  // 0..255
                const int is = idx >> 5;          // 0..7
                const int iq = idx & 31;          // 0..31
                const int qlIdx = (is >> 2) * 64 + (is & 1) * 32 + iq;
                const int qhIdx = (is >> 2) * 32 + iq;
                const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
                const unsigned char ql = ql_sh[warp][qlIdx];
                const unsigned char qh = qh_sh[warp][qhIdx];
                const int shift_ql = ((is & 3) < 2) ? 0 : 4;
                const int shift_qh = (is & 3) * 2;
                int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
                q -= 32;
                const float w = d * (float)scales_sh[warp][scIdx] * (float)q;
                sum += a_sh[idx] * w;
            }
        }
        __syncthreads();
    }
    // Warp reduction
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f32_q6k(float* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q6k_kernel<<<blocks, threads>>>(A, (const BlockQ6_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

// ============================================================
// FP16 Input Variants for Q4K, Q5K, Q2K, Q3K, Q6K
// Same logic as FP32 versions but load A as FP16
// ============================================================
__global__ void matmul_q4k_kernel_f16in(const __half* A, const BlockQ4_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ unsigned char sc_sh[8][8];
    __shared__ unsigned char m_sh[8][8];
    __shared__ float ds_sh[8][8];
    __shared__ float dm_sh[8][8];
    __shared__ float d_sh[8];
    __shared__ float dmin_sh[8];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        const int tid = warp * 32 + lane;
        const __half* aRow = A + row * K + blk * 256;
        a_sh[tid] = __half2float(aRow[tid]);
        if (colIn) {
            const BlockQ4_K* b = &B[col * blocksPerRow + blk];
            if (lane < 8) {
                unsigned char sc, mn;
                if (lane < 4) {
                    sc = b->scales[lane] & 63;
                    mn = b->scales[lane + 4] & 63;
                } else {
                    const int j = lane;
                    sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
                    mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
                }
                sc_sh[warp][lane] = sc;
                m_sh[warp][lane] = mn;
            }
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
                dmin_sh[warp] = fp16_to_fp32(b->dmin);
            }
        }
        __syncthreads();
        if (colIn) {
            const BlockQ4_K* b = &B[col * blocksPerRow + blk];
            if (lane < 8) {
                ds_sh[warp][lane] = d_sh[warp] * (float)sc_sh[warp][lane];
                dm_sh[warp][lane] = dmin_sh[warp] * (float)m_sh[warp][lane];
            }
            __syncwarp();
            #pragma unroll
            for (int p = 0; p < 4; p++) {
                const unsigned char qs = b->qs[p * 32 + lane];
                const int q0 = qs & 0xF;
                const int q1 = qs >> 4;
                const int idx0 = p * 64 + lane;
                const int idx1 = idx0 + 32;
                const int g0 = 2 * p;
                const int g1 = g0 + 1;
                sum += a_sh[idx0] * ((float)q0 * ds_sh[warp][g0] - dm_sh[warp][g0]);
                sum += a_sh[idx1] * ((float)q1 * ds_sh[warp][g1] - dm_sh[warp][g1]);
            }
        }
        __syncthreads();
    }
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f16_q4k(const void* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q4k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
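// Rough bandwidth arithmetic behind these FP16 activation variants
// (illustrative numbers, not measurements): for K = 4096 a full row of A is
// 4096 * 2 B = 8 KiB in FP16 versus 4096 * 4 B = 16 KiB in FP32, halving the
// activation traffic per output while the quantized weight traffic and the
// FP32 accumulation are unchanged.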
__global__ void matmul_q5k_kernel_f16in(const __half* A, const BlockQ5_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ unsigned char sc_sh[8][8];
    __shared__ unsigned char m_sh[8][8];
    __shared__ float ds_sh[8][8];
    __shared__ float dm_sh[8][8];
    __shared__ float d_sh[8];
    __shared__ float dmin_sh[8];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        const int tid = warp * 32 + lane;
        const __half* aRow = A + row * K + blk * 256;
        a_sh[tid] = __half2float(aRow[tid]);
        if (colIn) {
            const BlockQ5_K* b = &B[col * blocksPerRow + blk];
            if (lane < 8) {
                unsigned char sc, mn;
                if (lane < 4) {
                    sc = b->scales[lane] & 63;
                    mn = b->scales[lane + 4] & 63;
                } else {
                    const int j = lane;
                    sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
                    mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
                }
                sc_sh[warp][lane] = sc;
                m_sh[warp][lane] = mn;
            }
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
                dmin_sh[warp] = fp16_to_fp32(b->dmin);
            }
        }
        __syncthreads();
        if (colIn) {
            const BlockQ5_K* b = &B[col * blocksPerRow + blk];
            if (lane < 8) {
                ds_sh[warp][lane] = d_sh[warp] * (float)sc_sh[warp][lane];
                dm_sh[warp][lane] = dmin_sh[warp] * (float)m_sh[warp][lane];
            }
            __syncwarp();
            const unsigned char hb = b->qh[lane];
            #pragma unroll
            for (int p = 0; p < 4; p++) {
                const unsigned char qs = b->qs[p * 32 + lane];
                int q0 = qs & 0xF;
                int q1 = qs >> 4;
                q0 += ((hb >> (2 * p)) & 1) << 4;
                q1 += ((hb >> (2 * p + 1)) & 1) << 4;
                const int idx0 = p * 64 + lane;
                const int idx1 = idx0 + 32;
                const int g0 = 2 * p;
                const int g1 = g0 + 1;
                sum += a_sh[idx0] * ((float)q0 * ds_sh[warp][g0] - dm_sh[warp][g0]);
                sum += a_sh[idx1] * ((float)q1 * ds_sh[warp][g1] - dm_sh[warp][g1]);
            }
        }
        __syncthreads();
    }
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f16_q5k(const void* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q5k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ5_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

__global__ void matmul_q2k_kernel_f16in(const __half* A, const BlockQ2_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ float d_sh[8];
    __shared__ float dmin_sh[8];
    __shared__ unsigned char scales_sh[8][16];
    __shared__ unsigned char qs_sh[8][64];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        const int tid = warp * 32 + lane;
        const __half* aRow = A + row * K + blk * 256;
        a_sh[tid] = __half2float(aRow[tid]);
        if (colIn) {
            const BlockQ2_K* b = &B[col * blocksPerRow + blk];
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
                dmin_sh[warp] = fp16_to_fp32(b->dmin);
            }
            if (lane < 16) {
                scales_sh[warp][lane] = b->scales[lane];
            }
            qs_sh[warp][lane] = b->qs[lane];
            qs_sh[warp][lane + 32] = b->qs[lane + 32];
        }
        __syncthreads();
        if (colIn) {
            const float d = d_sh[warp];
            const float dmin = dmin_sh[warp];
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                const int idx = lane + (i * 32);
                const int is = idx >> 5;
                const int iq = idx & 31;
                const int qsIdx = (is >> 2) * 32 + iq;
                const int shift = (is & 3) * 2;
                const int val = (qs_sh[warp][qsIdx] >> shift) & 3;
                const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
                const unsigned char sc = scales_sh[warp][scIdx];
                const float dl = d * (float)(sc & 0xF);
                const float ml = dmin * (float)(sc >> 4);
                sum += a_sh[idx] * (dl * (float)val - ml);
            }
        }
        __syncthreads();
    }
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f16_q2k(const void* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q2k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ2_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

__global__ void matmul_q3k_kernel_f16in(const __half* A, const BlockQ3_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ float d_sh[8];
    __shared__ unsigned char scales_sh[8][12];
    __shared__ unsigned char qs_sh[8][64];
    __shared__ unsigned char hmask_sh[8][32];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        const int tid = warp * 32 + lane;
        const __half* aRow = A + row * K + blk * 256;
        a_sh[tid] = __half2float(aRow[tid]);
        if (colIn) {
            const BlockQ3_K* b = &B[col * blocksPerRow + blk];
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
            }
            if (lane < 12) {
                scales_sh[warp][lane] = b->scales[lane];
            }
            qs_sh[warp][lane] = b->qs[lane];
            qs_sh[warp][lane + 32] = b->qs[lane + 32];
            hmask_sh[warp][lane] = b->hmask[lane];
        }
        __syncthreads();
        if (colIn) {
            const float d = d_sh[warp];
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                const int idx = lane + (i * 32);
                const int is = idx >> 5;
                const int iq = idx & 31;
                const int qsIdx = (is >> 2) * 32 + iq;
                const int shift = (is & 3) * 2;
                int qv = (qs_sh[warp][qsIdx] >> shift) & 0x3;
                const unsigned char m = (unsigned char)(1 << ((is >> 2) * 4 + (is & 3)));
                if ((hmask_sh[warp][iq] & m) == 0) {
                    qv -= 4;
                }
                const int sIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
                unsigned char sc;
                if (sIdx < 8) {
                    sc = scales_sh[warp][sIdx] & 0xF;
                } else {
                    sc = scales_sh[warp][sIdx - 8] >> 4;
                }
                sc |= ((scales_sh[warp][8 + (sIdx & 3)] >> (2 * (sIdx >> 2))) & 0x3) << 4;
                const float scale = (float)((int)((signed char)sc) - 32);
                sum += a_sh[idx] * (d * scale * (float)qv);
            }
        }
        __syncthreads();
    }
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f16_q3k(const void* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q3k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ3_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}

__global__ void matmul_q6k_kernel_f16in(const __half* A, const BlockQ6_K* B, float* C, int M, int K, int N, int blocksPerRow) {
    const int row = blockIdx.y;
    const int warp = threadIdx.y;
    const int lane = threadIdx.x;
    const int col = blockIdx.x * 8 + warp;
    if (row >= M) return;
    const bool colIn = (col < N);
    float sum = 0.0f;
    __shared__ float a_sh[256];
    __shared__ float d_sh[8];
    __shared__ signed char scales_sh[8][16];
    __shared__ unsigned char ql_sh[8][128];
    __shared__ unsigned char qh_sh[8][64];
    for (int blk = 0; blk < blocksPerRow; blk++) {
        const int tid = warp * 32 + lane;
        const __half* aRow = A + row * K + blk * 256;
        a_sh[tid] = __half2float(aRow[tid]);
        if (colIn) {
            const BlockQ6_K* b = &B[col * blocksPerRow + blk];
            if (lane == 0) {
                d_sh[warp] = fp16_to_fp32(b->d);
            }
            if (lane < 16) {
                scales_sh[warp][lane] = b->scales[lane];
            }
            qh_sh[warp][lane] = b->qh[lane];
            qh_sh[warp][lane + 32] = b->qh[lane + 32];
            ql_sh[warp][lane] = b->ql[lane];
            ql_sh[warp][lane + 32] = b->ql[lane + 32];
            ql_sh[warp][lane + 64] = b->ql[lane + 64];
            ql_sh[warp][lane + 96] = b->ql[lane + 96];
        }
        __syncthreads();
        if (colIn) {
            const float d = d_sh[warp];
            #pragma unroll
            for (int i = 0; i < 8; i++) {
                const int idx = lane + (i * 32);
                const int is = idx >> 5;
                const int iq = idx & 31;
                const int qlIdx = (is >> 2) * 64 + (is & 1) * 32 + iq;
                const int qhIdx = (is >> 2) * 32 + iq;
                const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
                const unsigned char ql = ql_sh[warp][qlIdx];
                const unsigned char qh = qh_sh[warp][qhIdx];
                const int shift_ql = ((is & 3) < 2) ? 0 : 4;
                const int shift_qh = (is & 3) * 2;
                int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
                q -= 32;
                sum += a_sh[idx] * (d * (float)scales_sh[warp][scIdx] * (float)q);
            }
        }
        __syncthreads();
    }
    for (int offset = 16; offset > 0; offset >>= 1) {
        sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    if (colIn && lane == 0) {
        C[row * N + col] = sum;
    }
}

int cuda_matmul_f16_q6k(const void* A, const void* B, float* C, int M, int K, int N) {
    int blocksPerRow = K / 256;
    dim3 threads(32, 8);
    dim3 blocks((N + 7) / 8, M);
    matmul_q6k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ6_K*)B, C, M, K, N, blocksPerRow);
    CHECK_CUDA(cudaGetLastError());
    return 0;
}
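// Note on the quantized entry points in this file: blocksPerRow = K / 256, so
// K is expected to be a multiple of 256 (the K-quant super-block size), and B
// must hold the quantized weight matrix laid out per output column, i.e. N
// rows of K/256 blocks, since the kernels index it as
// B[col * blocksPerRow + blk]. A is the activation matrix [M, K] (FP32 or
// FP16 depending on the variant) and C receives the FP32 result [M, N].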