- #include "cuda_common.cuh"
- #include <cuda_fp16.h>
- namespace {
- // Simple tiled GEMM kernels for correctness-first dense matmul.
- // These are used as the default dense GEMM path when CUTLASS is not built.
- constexpr int TILE = 16;
- __global__ void matmul_f32_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C,
- int M, int K, int N) {
- __shared__ float As[TILE][TILE];
- __shared__ float Bs[TILE][TILE];
- const int row = blockIdx.y * TILE + threadIdx.y;
- const int col = blockIdx.x * TILE + threadIdx.x;
- float acc = 0.0f;
- for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
- const int aCol = t * TILE + threadIdx.x;
- const int bRow = t * TILE + threadIdx.y;
- As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
- Bs[threadIdx.y][threadIdx.x] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0f;
- __syncthreads();
- #pragma unroll
- for (int i = 0; i < TILE; ++i) {
- acc += As[threadIdx.y][i] * Bs[i][threadIdx.x];
- }
- __syncthreads();
- }
- if (row < M && col < N) {
- C[row * N + col] = acc;
- }
- }
- // Computes C = A @ B^T where B is stored row-major [N, K].
- __global__ void matmul_f32_nt_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C,
- int M, int K, int N) {
- __shared__ float As[TILE][TILE];
- __shared__ float Bs[TILE][TILE];
- const int row = blockIdx.y * TILE + threadIdx.y;
- const int col = blockIdx.x * TILE + threadIdx.x; // maps to n
- float acc = 0.0f;
- for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
- const int aCol = t * TILE + threadIdx.x;
- const int bCol = t * TILE + threadIdx.y; // k index for B row
- As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
- Bs[threadIdx.y][threadIdx.x] = (col < N && bCol < K) ? B[col * K + bCol] : 0.0f;
- __syncthreads();
- #pragma unroll
- for (int i = 0; i < TILE; ++i) {
- acc += As[threadIdx.y][i] * Bs[i][threadIdx.x];
- }
- __syncthreads();
- }
- if (row < M && col < N) {
- C[row * N + col] = acc;
- }
- }
- // Computes C = A @ B^T for FP16 A and B (the host wrapper receives them as uint16 and reinterprets them as __half).
- __global__ void matmul_f16_nt_kernel(const __half* __restrict__ A, const __half* __restrict__ B, float* __restrict__ C,
- int M, int K, int N) {
- __shared__ __half As[TILE][TILE];
- __shared__ __half Bs[TILE][TILE];
- const int row = blockIdx.y * TILE + threadIdx.y;
- const int col = blockIdx.x * TILE + threadIdx.x;
- float acc = 0.0f;
- for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
- const int aCol = t * TILE + threadIdx.x;
- const int bCol = t * TILE + threadIdx.y;
- As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : __float2half(0.0f);
- Bs[threadIdx.y][threadIdx.x] = (col < N && bCol < K) ? B[col * K + bCol] : __float2half(0.0f);
- __syncthreads();
- #pragma unroll
- for (int i = 0; i < TILE; ++i) {
- acc += __half2float(As[threadIdx.y][i]) * __half2float(Bs[i][threadIdx.x]);
- }
- __syncthreads();
- }
- if (row < M && col < N) {
- C[row * N + col] = acc;
- }
- }
- } // namespace
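- // ============================================================
- // Fused Q5_K MatMul Kernel
- // C[m,n] = sum_k A[m,k] * dequant(B[n,k])
- // Launch layout shared by all fused K-quant kernels below:
- //   threads = (32, 8): 8 warps per block; warp w computes output
- //   column col = blockIdx.x * 8 + w, and blockIdx.y selects the row.
- //   Each 256-wide quant block of the weight row is processed with the
- //   A tile cached in shared memory, then a warp-shuffle reduction
- //   combines the 32 partial sums per column.
- // Assumes K is a multiple of 256 (one BlockQ*_K covers 256 weights).
- // Q5_K block: fp16 d/dmin, scales[12] packing eight 6-bit scales and
- //   eight 6-bit mins (one per group of 32 values), qh[32] high bits,
- //   qs[128] low nibbles; weight = d*scale*q - dmin*min with q in [0,31].
- // ============================================================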
- __global__ void matmul_q5k_kernel(float* A, const BlockQ5_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- // row is uniform across the block, so an early return here is safe.
- if (row >= M) return;
- // col is warp-specific. Do NOT early-return on col>=N because we use __syncthreads().
- const bool colIn = (col < N);
- float sum = 0.0f;
- // Cache the A tile (256 floats) once per block so the 8 warps (8 columns) reuse it.
- __shared__ float a_sh[256];
- __shared__ unsigned char sc_sh[8][8];
- __shared__ unsigned char m_sh[8][8];
- __shared__ float ds_sh[8][8];
- __shared__ float dm_sh[8][8];
- __shared__ float d_sh[8];
- __shared__ float dmin_sh[8];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- // Cache A tile once per block (256 floats). Each thread loads one element.
- const int tid = warp * 32 + lane;
- const float* aRow = A + row * K + blk * 256;
- a_sh[tid] = aRow[tid];
- if (colIn) {
- const BlockQ5_K* b = &B[col * blocksPerRow + blk];
- if (lane < 8) {
- unsigned char sc;
- unsigned char mn;
- if (lane < 4) {
- sc = b->scales[lane] & 63;
- mn = b->scales[lane + 4] & 63;
- } else {
- const int j = lane;
- sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
- mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
- }
- sc_sh[warp][lane] = sc;
- m_sh[warp][lane] = mn;
- }
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- dmin_sh[warp] = fp16_to_fp32(b->dmin);
- }
- }
- // Ensure all warps have finished loading this block's a_sh before use,
- // and that no warp overwrites a_sh while others are still reading it.
- __syncthreads();
- if (colIn) {
- const BlockQ5_K* b = &B[col * blocksPerRow + blk];
- // Precompute per-group multipliers once (one lane per group).
- if (lane < 8) {
- const float d = d_sh[warp];
- const float dmin = dmin_sh[warp];
- const unsigned char sc = sc_sh[warp][lane];
- const unsigned char mn = m_sh[warp][lane];
- ds_sh[warp][lane] = d * (float)sc;
- dm_sh[warp][lane] = dmin * (float)mn;
- }
- __syncwarp();
- const unsigned char hb = b->qh[lane];
- #pragma unroll
- for (int p = 0; p < 4; p++) {
- const unsigned char qs = b->qs[p * 32 + lane];
- int q0 = qs & 0xF;
- int q1 = qs >> 4;
- q0 += ((hb >> (2 * p)) & 1) << 4;
- q1 += ((hb >> (2 * p + 1)) & 1) << 4;
- const int idx0 = p * 64 + lane;
- const int idx1 = idx0 + 32;
- const int g0 = 2 * p;
- const int g1 = g0 + 1;
- const float ds0 = ds_sh[warp][g0];
- const float dm0 = dm_sh[warp][g0];
- const float ds1 = ds_sh[warp][g1];
- const float dm1 = dm_sh[warp][g1];
- sum += a_sh[idx0] * ((float)q0 * ds0 - dm0);
- sum += a_sh[idx1] * ((float)q1 * ds1 - dm1);
- }
- }
- __syncthreads();
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f32_q5k(float* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- matmul_q5k_kernel<<<blocks, threads>>>(A, (const BlockQ5_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- int cuda_matmul_f32(float* A, float* B, float* C, int M, int K, int N) {
- if (M <= 0 || N <= 0 || K <= 0) return 0;
- dim3 threads(TILE, TILE);
- dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
- matmul_f32_kernel<<<blocks, threads>>>(A, B, C, M, K, N);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- int cuda_matmul_f32_nt(float* A, float* B, float* C, int M, int K, int N) {
- if (M <= 0 || N <= 0 || K <= 0) return 0;
- dim3 threads(TILE, TILE);
- dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
- matmul_f32_nt_kernel<<<blocks, threads>>>(A, B, C, M, K, N);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- int cuda_matmul_f16_nt(const unsigned short* A, const unsigned short* B, float* C, int M, int K, int N) {
- if (M <= 0 || N <= 0 || K <= 0) return 0;
- dim3 threads(TILE, TILE);
- dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
- matmul_f16_nt_kernel<<<blocks, threads>>>(reinterpret_cast<const __half*>(A), reinterpret_cast<const __half*>(B), C, M, K, N);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
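- // Illustrative usage of the dense wrappers (a minimal sketch, assuming d_A,
- // d_B and d_C are device buffers already allocated with cudaMalloc and that
- // the caller synchronizes before reading results; not part of this file's API):
- //
- //   // C = A @ B with A:[M,K], B:[K,N], C:[M,N], all row-major FP32.
- //   cuda_matmul_f32(d_A, d_B, d_C, M, K, N);
- //   // C = A @ B^T with B stored row-major as [N,K] (typical weight layout).
- //   cuda_matmul_f32_nt(d_A, d_B, d_C, M, K, N);
- //   cudaDeviceSynchronize();  // the kernel launches above are asynchronous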
- // ============================================================
- // Fused Q8_K MatMul Kernel (tiled)
- // C[m,n] = sum_k A[m,k] * dequant(B[n,k])
- // Uses shared memory tiles to reduce global memory pressure.
- // ============================================================
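- // Each Q8_K block stores one float scale d and 256 int8 quants, so a weight
- // is simply w = d * qs[i]. Within a block, lane l of the active warp handles
- // the 8 elements l, l+32, ..., l+224.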
- __global__ void matmul_q8k_kernel(float* A, const BlockQ8_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ float d_sh[8];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- const int tid = warp * 32 + lane;
- const float* aRow = A + row * K + blk * 256;
- a_sh[tid] = aRow[tid];
- if (colIn && lane == 0) {
- d_sh[warp] = B[col * blocksPerRow + blk].d;
- }
- __syncthreads();
- if (colIn) {
- const BlockQ8_K* b = &B[col * blocksPerRow + blk];
- const float d = d_sh[warp];
- // Each lane handles 8 weights in the 256-wide block.
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int idx = lane + (i * 32); // 0..255
- const float w = d * (float)((int)b->qs[idx]);
- sum += a_sh[idx] * w;
- }
- }
- __syncthreads();
- }
- // Warp reduction
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f32_q8k(float* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
-
- matmul_q8k_kernel<<<blocks, threads>>>(A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- int cuda_matmul_f32_q8k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms) {
- cudaEvent_t evStart;
- cudaEvent_t evStop;
- CHECK_CUDA(cudaEventCreate(&evStart));
- CHECK_CUDA(cudaEventCreate(&evStop));
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- CHECK_CUDA(cudaEventRecord(evStart));
- matmul_q8k_kernel<<<blocks, threads>>>(A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaEventRecord(evStop));
- CHECK_CUDA(cudaEventSynchronize(evStop));
- float elapsed = 0.0f;
- CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
- if (ms != NULL) {
- *ms = elapsed;
- }
- CHECK_CUDA(cudaEventDestroy(evStart));
- CHECK_CUDA(cudaEventDestroy(evStop));
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- // ============================================================
- // FP16 Input Variants - halve activation memory traffic
- // Input A is FP16; weights are dequantized in FP32,
- // accumulation is in FP32, and the output is FP32.
- // ============================================================
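- // Usage sketch (illustrative; assumes the caller has already produced an FP16
- // copy of the activations on the device, e.g. via a separate conversion kernel
- // not shown in this file, held here under the assumed name d_A_fp16):
- //
- //   cuda_matmul_f16_q8k(d_A_fp16, d_W_q8k, d_C, M, K, N);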
- __global__ void matmul_q8k_kernel_f16in(const __half* A, const BlockQ8_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ float d_sh[8];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- const int tid = warp * 32 + lane;
- const __half* aRow = A + row * K + blk * 256;
- // Load FP16, convert to FP32 in shared memory
- a_sh[tid] = __half2float(aRow[tid]);
- if (colIn && lane == 0) {
- d_sh[warp] = B[col * blocksPerRow + blk].d;
- }
- __syncthreads();
- if (colIn) {
- const BlockQ8_K* b = &B[col * blocksPerRow + blk];
- const float d = d_sh[warp];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int idx = lane + (i * 32);
- const float w = d * (float)((int)b->qs[idx]);
- sum += a_sh[idx] * w;
- }
- }
- __syncthreads();
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f16_q8k(const void* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
-
- matmul_q8k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- // ============================================================
- // Fused Q4_K MatMul Kernel
- // Caches the A tile and per-group scale/min multipliers in shared memory.
- // ============================================================
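- // Q4_K block: fp16 d/dmin, scales[12], qs[128] with two 4-bit quants per byte.
- // Packing of the eight 6-bit scales/mins: groups 0-3 use the low 6 bits of
- // bytes 0-3 (scales) and 4-7 (mins); groups 4-7 take their low 4 bits from
- // bytes 8-11 (scale = low nibble, min = high nibble) and their top 2 bits from
- // bytes 0-7. Weight = d*scale*q - dmin*min with q in [0,15].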
- __global__ void matmul_q4k_kernel(float* A, const BlockQ4_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ unsigned char sc_sh[8][8];
- __shared__ unsigned char m_sh[8][8];
- __shared__ float ds_sh[8][8];
- __shared__ float dm_sh[8][8];
- __shared__ float d_sh[8];
- __shared__ float dmin_sh[8];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- const int tid = warp * 32 + lane;
- const float* aRow = A + row * K + blk * 256;
- a_sh[tid] = aRow[tid];
- if (colIn) {
- const BlockQ4_K* b = &B[col * blocksPerRow + blk];
- // Parallel unpack scale/min for groups 0..7 (one lane per group).
- if (lane < 8) {
- unsigned char sc;
- unsigned char mn;
- if (lane < 4) {
- sc = b->scales[lane] & 63;
- mn = b->scales[lane + 4] & 63;
- } else {
- const int j = lane;
- sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
- mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
- }
- sc_sh[warp][lane] = sc;
- m_sh[warp][lane] = mn;
- }
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- dmin_sh[warp] = fp16_to_fp32(b->dmin);
- }
- }
- __syncthreads();
- if (colIn) {
- const BlockQ4_K* b = &B[col * blocksPerRow + blk];
- // Precompute per-group float multipliers once.
- if (lane < 8) {
- const float d = d_sh[warp];
- const float dmin = dmin_sh[warp];
- const unsigned char sc = sc_sh[warp][lane];
- const unsigned char mn = m_sh[warp][lane];
- ds_sh[warp][lane] = d * (float)sc;
- dm_sh[warp][lane] = dmin * (float)mn;
- }
- __syncwarp();
- const float ds0 = ds_sh[warp][0];
- const float dm0 = dm_sh[warp][0];
- const float ds1 = ds_sh[warp][1];
- const float dm1 = dm_sh[warp][1];
- const float ds2 = ds_sh[warp][2];
- const float dm2 = dm_sh[warp][2];
- const float ds3 = ds_sh[warp][3];
- const float dm3 = dm_sh[warp][3];
- const float ds4 = ds_sh[warp][4];
- const float dm4 = dm_sh[warp][4];
- const float ds5 = ds_sh[warp][5];
- const float dm5 = dm_sh[warp][5];
- const float ds6 = ds_sh[warp][6];
- const float dm6 = dm_sh[warp][6];
- const float ds7 = ds_sh[warp][7];
- const float dm7 = dm_sh[warp][7];
- // Each lane processes 4 bytes; each byte contains 2 nibbles => 8 values per lane.
- // This halves qs loads and reduces bit ops.
- #pragma unroll
- for (int p = 0; p < 4; p++) {
- const unsigned char qs = b->qs[p * 32 + lane];
- const int q0 = qs & 0xF;
- const int q1 = qs >> 4;
- const int idx0 = p * 64 + lane; // group = 2*p
- const int idx1 = idx0 + 32; // group = 2*p + 1
- float dsA, dmA, dsB, dmB;
- if (p == 0) {
- dsA = ds0; dmA = dm0; dsB = ds1; dmB = dm1;
- } else if (p == 1) {
- dsA = ds2; dmA = dm2; dsB = ds3; dmB = dm3;
- } else if (p == 2) {
- dsA = ds4; dmA = dm4; dsB = ds5; dmB = dm5;
- } else {
- dsA = ds6; dmA = dm6; dsB = ds7; dmB = dm7;
- }
- const float w0 = (float)q0 * dsA - dmA;
- const float w1 = (float)q1 * dsB - dmB;
- sum += a_sh[idx0] * w0;
- sum += a_sh[idx1] * w1;
- }
- }
- __syncthreads();
- }
- // Warp reduction
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f32_q4k(float* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
-
- matmul_q4k_kernel<<<blocks, threads>>>(A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- int cuda_matmul_f32_q4k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms) {
- cudaEvent_t evStart;
- cudaEvent_t evStop;
- CHECK_CUDA(cudaEventCreate(&evStart));
- CHECK_CUDA(cudaEventCreate(&evStop));
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- CHECK_CUDA(cudaEventRecord(evStart));
- matmul_q4k_kernel<<<blocks, threads>>>(A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaEventRecord(evStop));
- CHECK_CUDA(cudaEventSynchronize(evStop));
- float elapsed = 0.0f;
- CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
- if (ms != NULL) {
- *ms = elapsed;
- }
- CHECK_CUDA(cudaEventDestroy(evStart));
- CHECK_CUDA(cudaEventDestroy(evStop));
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- // ============================================================
- // Fused Q2_K MatMul Kernel
- // ============================================================
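- // Q2_K block: fp16 d/dmin, scales[16] with one byte per 16-element group
- // (low nibble = scale, high nibble = min), qs[64] packing four 2-bit quants
- // per byte. Weight = d*(sc & 0xF)*q - dmin*(sc >> 4) with q in [0,3].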
- __global__ void matmul_q2k_kernel(float* A, const BlockQ2_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ float d_sh[8];
- __shared__ float dmin_sh[8];
- __shared__ unsigned char scales_sh[8][16];
- __shared__ unsigned char qs_sh[8][64];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- // Cache A tile once per block (256 floats) to avoid redundant global loads.
- // Each thread loads one element: tid in [0,255].
- const int tid = warp * 32 + lane;
- const float* aRow = A + row * K + blk * 256;
- a_sh[tid] = aRow[tid];
- if (colIn) {
- const BlockQ2_K* b = &B[col * blocksPerRow + blk];
- // Cooperative per-warp cache.
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- dmin_sh[warp] = fp16_to_fp32(b->dmin);
- }
- if (lane < 16) {
- scales_sh[warp][lane] = b->scales[lane];
- }
- // Load the 64 qs bytes with 32 lanes (2 bytes per lane).
- qs_sh[warp][lane] = b->qs[lane];
- qs_sh[warp][lane + 32] = b->qs[lane + 32];
- }
- __syncthreads();
- if (colIn) {
- const float d = d_sh[warp];
- const float dmin = dmin_sh[warp];
- // Each lane handles 8 values.
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int idx = lane + (i * 32); // 0..255
- const int is = idx >> 5; // 0..7
- const int iq = idx & 31; // 0..31
- const int qsIdx = (is >> 2) * 32 + iq;
- const int shift = (is & 3) * 2;
- const int val = (qs_sh[warp][qsIdx] >> shift) & 3;
- const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
- const unsigned char sc = scales_sh[warp][scIdx];
- const float dl = d * (float)(sc & 0xF);
- const float ml = dmin * (float)(sc >> 4);
- const float w = dl * (float)val - ml;
- sum += a_sh[idx] * w;
- }
- }
- // Ensure all warps finished reading this block's a_sh before it is overwritten.
- __syncthreads();
- }
- // Warp reduction
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f32_q2k(float* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
-
- matmul_q2k_kernel<<<blocks, threads>>>(A, (const BlockQ2_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- // ============================================================
- // Fused Q3_K MatMul Kernel
- // ============================================================
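- // Q3_K block: hmask[32] high bits, qs[64] low 2-bit quants, scales[12]
- // packing sixteen 6-bit scales, fp16 d. A quant is the 2-bit value plus a
- // high bit from hmask (a cleared bit subtracts 4), and the weight is
- // d * (scale - 32) * q.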
- __global__ void matmul_q3k_kernel(float* A, const BlockQ3_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ float d_sh[8];
- __shared__ unsigned char scales_sh[8][12];
- __shared__ unsigned char qs_sh[8][64];
- __shared__ unsigned char hmask_sh[8][32];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- // Cache A tile once per block (256 floats). Each thread loads one element.
- const int tid = warp * 32 + lane;
- const float* aRow = A + row * K + blk * 256;
- a_sh[tid] = aRow[tid];
- if (colIn) {
- const BlockQ3_K* b = &B[col * blocksPerRow + blk];
- // Cache quant block bytes.
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- }
- if (lane < 12) {
- scales_sh[warp][lane] = b->scales[lane];
- }
- // qs: 64 bytes
- qs_sh[warp][lane] = b->qs[lane];
- qs_sh[warp][lane + 32] = b->qs[lane + 32];
- // hmask: 32 bytes
- hmask_sh[warp][lane] = b->hmask[lane];
- }
- __syncthreads();
- if (colIn) {
- const float d = d_sh[warp];
- // Each lane handles 8 elements.
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int idx = lane + (i * 32); // 0..255
- const int is = idx >> 5; // 0..7
- const int iq = idx & 31; // 0..31
- const int qsIdx = (is >> 2) * 32 + iq;
- const int shift = (is & 3) * 2;
- int qv = (qs_sh[warp][qsIdx] >> shift) & 0x3;
- const unsigned char m = (unsigned char)(1 << ((is >> 2) * 4 + (is & 3)));
- if ((hmask_sh[warp][iq] & m) == 0) {
- qv -= 4;
- }
- const int sIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4); // 0..15
- unsigned char sc;
- if (sIdx < 8) {
- sc = scales_sh[warp][sIdx] & 0xF;
- } else {
- sc = scales_sh[warp][sIdx - 8] >> 4;
- }
- sc |= ((scales_sh[warp][8 + (sIdx & 3)] >> (2 * (sIdx >> 2))) & 0x3) << 4;
- const float scale = (float)((int)((signed char)sc) - 32);
- const float w = d * scale * (float)qv;
- sum += a_sh[idx] * w;
- }
- }
- __syncthreads();
- }
- // Warp reduction
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f32_q3k(float* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- matmul_q3k_kernel<<<blocks, threads>>>(A, (const BlockQ3_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- // ============================================================
- // Fused Q6_K MatMul Kernel
- // ============================================================
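- // Q6_K block: ql[128] low 4 bits, qh[64] high 2 bits, scales[16] signed
- // 8-bit scales (one per 16-element group), fp16 d. A quant is the 6-bit
- // value minus 32, and the weight is d * scales[g] * q.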
- __global__ void matmul_q6k_kernel(float* A, const BlockQ6_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ float d_sh[8];
- __shared__ signed char scales_sh[8][16];
- __shared__ unsigned char ql_sh[8][128];
- __shared__ unsigned char qh_sh[8][64];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- // Cache A tile once per block.
- const int tid = warp * 32 + lane;
- const float* aRow = A + row * K + blk * 256;
- a_sh[tid] = aRow[tid];
- if (colIn) {
- const BlockQ6_K* b = &B[col * blocksPerRow + blk];
- // Cache quant block bytes.
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- }
- if (lane < 16) {
- scales_sh[warp][lane] = b->scales[lane];
- }
- // qh: 64 bytes
- qh_sh[warp][lane] = b->qh[lane];
- qh_sh[warp][lane + 32] = b->qh[lane + 32];
- // ql: 128 bytes
- ql_sh[warp][lane] = b->ql[lane];
- ql_sh[warp][lane + 32] = b->ql[lane + 32];
- ql_sh[warp][lane + 64] = b->ql[lane + 64];
- ql_sh[warp][lane + 96] = b->ql[lane + 96];
- }
- __syncthreads();
- if (colIn) {
- const float d = d_sh[warp];
- // Each lane handles 8 elements.
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int idx = lane + (i * 32); // 0..255
- const int is = idx >> 5; // 0..7
- const int iq = idx & 31; // 0..31
- const int qlIdx = (is >> 2) * 64 + (is & 1) * 32 + iq;
- const int qhIdx = (is >> 2) * 32 + iq;
- const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
- const unsigned char ql = ql_sh[warp][qlIdx];
- const unsigned char qh = qh_sh[warp][qhIdx];
- const int shift_ql = ((is & 3) < 2) ? 0 : 4;
- const int shift_qh = (is & 3) * 2;
- int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
- q -= 32;
- const float w = d * (float)scales_sh[warp][scIdx] * (float)q;
- sum += a_sh[idx] * w;
- }
- }
- __syncthreads();
- }
- // Warp reduction
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f32_q6k(float* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- matmul_q6k_kernel<<<blocks, threads>>>(A, (const BlockQ6_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- // ============================================================
- // FP16 Input Variants for Q4K, Q5K, Q2K, Q3K, Q6K
- // Same logic as the FP32 versions, but A is loaded as FP16.
- // ============================================================
- __global__ void matmul_q4k_kernel_f16in(const __half* A, const BlockQ4_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ unsigned char sc_sh[8][8];
- __shared__ unsigned char m_sh[8][8];
- __shared__ float ds_sh[8][8];
- __shared__ float dm_sh[8][8];
- __shared__ float d_sh[8];
- __shared__ float dmin_sh[8];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- const int tid = warp * 32 + lane;
- const __half* aRow = A + row * K + blk * 256;
- a_sh[tid] = __half2float(aRow[tid]);
- if (colIn) {
- const BlockQ4_K* b = &B[col * blocksPerRow + blk];
- if (lane < 8) {
- unsigned char sc, mn;
- if (lane < 4) {
- sc = b->scales[lane] & 63;
- mn = b->scales[lane + 4] & 63;
- } else {
- const int j = lane;
- sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
- mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
- }
- sc_sh[warp][lane] = sc;
- m_sh[warp][lane] = mn;
- }
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- dmin_sh[warp] = fp16_to_fp32(b->dmin);
- }
- }
- __syncthreads();
- if (colIn) {
- const BlockQ4_K* b = &B[col * blocksPerRow + blk];
- if (lane < 8) {
- ds_sh[warp][lane] = d_sh[warp] * (float)sc_sh[warp][lane];
- dm_sh[warp][lane] = dmin_sh[warp] * (float)m_sh[warp][lane];
- }
- __syncwarp();
- #pragma unroll
- for (int p = 0; p < 4; p++) {
- const unsigned char qs = b->qs[p * 32 + lane];
- const int q0 = qs & 0xF;
- const int q1 = qs >> 4;
- const int idx0 = p * 64 + lane;
- const int idx1 = idx0 + 32;
- const int g0 = 2 * p;
- const int g1 = g0 + 1;
- sum += a_sh[idx0] * ((float)q0 * ds_sh[warp][g0] - dm_sh[warp][g0]);
- sum += a_sh[idx1] * ((float)q1 * ds_sh[warp][g1] - dm_sh[warp][g1]);
- }
- }
- __syncthreads();
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f16_q4k(const void* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- matmul_q4k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- __global__ void matmul_q5k_kernel_f16in(const __half* A, const BlockQ5_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ unsigned char sc_sh[8][8];
- __shared__ unsigned char m_sh[8][8];
- __shared__ float ds_sh[8][8];
- __shared__ float dm_sh[8][8];
- __shared__ float d_sh[8];
- __shared__ float dmin_sh[8];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- const int tid = warp * 32 + lane;
- const __half* aRow = A + row * K + blk * 256;
- a_sh[tid] = __half2float(aRow[tid]);
- if (colIn) {
- const BlockQ5_K* b = &B[col * blocksPerRow + blk];
- if (lane < 8) {
- unsigned char sc, mn;
- if (lane < 4) {
- sc = b->scales[lane] & 63;
- mn = b->scales[lane + 4] & 63;
- } else {
- const int j = lane;
- sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
- mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
- }
- sc_sh[warp][lane] = sc;
- m_sh[warp][lane] = mn;
- }
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- dmin_sh[warp] = fp16_to_fp32(b->dmin);
- }
- }
- __syncthreads();
- if (colIn) {
- const BlockQ5_K* b = &B[col * blocksPerRow + blk];
- if (lane < 8) {
- ds_sh[warp][lane] = d_sh[warp] * (float)sc_sh[warp][lane];
- dm_sh[warp][lane] = dmin_sh[warp] * (float)m_sh[warp][lane];
- }
- __syncwarp();
- const unsigned char hb = b->qh[lane];
- #pragma unroll
- for (int p = 0; p < 4; p++) {
- const unsigned char qs = b->qs[p * 32 + lane];
- int q0 = qs & 0xF;
- int q1 = qs >> 4;
- q0 += ((hb >> (2 * p)) & 1) << 4;
- q1 += ((hb >> (2 * p + 1)) & 1) << 4;
- const int idx0 = p * 64 + lane;
- const int idx1 = idx0 + 32;
- const int g0 = 2 * p;
- const int g1 = g0 + 1;
- sum += a_sh[idx0] * ((float)q0 * ds_sh[warp][g0] - dm_sh[warp][g0]);
- sum += a_sh[idx1] * ((float)q1 * ds_sh[warp][g1] - dm_sh[warp][g1]);
- }
- }
- __syncthreads();
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f16_q5k(const void* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- matmul_q5k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ5_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- __global__ void matmul_q2k_kernel_f16in(const __half* A, const BlockQ2_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ float d_sh[8];
- __shared__ float dmin_sh[8];
- __shared__ unsigned char scales_sh[8][16];
- __shared__ unsigned char qs_sh[8][64];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- const int tid = warp * 32 + lane;
- const __half* aRow = A + row * K + blk * 256;
- a_sh[tid] = __half2float(aRow[tid]);
- if (colIn) {
- const BlockQ2_K* b = &B[col * blocksPerRow + blk];
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- dmin_sh[warp] = fp16_to_fp32(b->dmin);
- }
- if (lane < 16) {
- scales_sh[warp][lane] = b->scales[lane];
- }
- qs_sh[warp][lane] = b->qs[lane];
- qs_sh[warp][lane + 32] = b->qs[lane + 32];
- }
- __syncthreads();
- if (colIn) {
- const float d = d_sh[warp];
- const float dmin = dmin_sh[warp];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int idx = lane + (i * 32);
- const int is = idx >> 5;
- const int iq = idx & 31;
- const int qsIdx = (is >> 2) * 32 + iq;
- const int shift = (is & 3) * 2;
- const int val = (qs_sh[warp][qsIdx] >> shift) & 3;
- const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
- const unsigned char sc = scales_sh[warp][scIdx];
- const float dl = d * (float)(sc & 0xF);
- const float ml = dmin * (float)(sc >> 4);
- sum += a_sh[idx] * (dl * (float)val - ml);
- }
- }
- __syncthreads();
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f16_q2k(const void* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- matmul_q2k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ2_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- __global__ void matmul_q3k_kernel_f16in(const __half* A, const BlockQ3_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ float d_sh[8];
- __shared__ unsigned char scales_sh[8][12];
- __shared__ unsigned char qs_sh[8][64];
- __shared__ unsigned char hmask_sh[8][32];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- const int tid = warp * 32 + lane;
- const __half* aRow = A + row * K + blk * 256;
- a_sh[tid] = __half2float(aRow[tid]);
- if (colIn) {
- const BlockQ3_K* b = &B[col * blocksPerRow + blk];
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- }
- if (lane < 12) {
- scales_sh[warp][lane] = b->scales[lane];
- }
- qs_sh[warp][lane] = b->qs[lane];
- qs_sh[warp][lane + 32] = b->qs[lane + 32];
- hmask_sh[warp][lane] = b->hmask[lane];
- }
- __syncthreads();
- if (colIn) {
- const float d = d_sh[warp];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int idx = lane + (i * 32);
- const int is = idx >> 5;
- const int iq = idx & 31;
- const int qsIdx = (is >> 2) * 32 + iq;
- const int shift = (is & 3) * 2;
- int qv = (qs_sh[warp][qsIdx] >> shift) & 0x3;
- const unsigned char m = (unsigned char)(1 << ((is >> 2) * 4 + (is & 3)));
- if ((hmask_sh[warp][iq] & m) == 0) {
- qv -= 4;
- }
- const int sIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
- unsigned char sc;
- if (sIdx < 8) {
- sc = scales_sh[warp][sIdx] & 0xF;
- } else {
- sc = scales_sh[warp][sIdx - 8] >> 4;
- }
- sc |= ((scales_sh[warp][8 + (sIdx & 3)] >> (2 * (sIdx >> 2))) & 0x3) << 4;
- const float scale = (float)((int)((signed char)sc) - 32);
- sum += a_sh[idx] * (d * scale * (float)qv);
- }
- }
- __syncthreads();
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f16_q3k(const void* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- matmul_q3k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ3_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
- __global__ void matmul_q6k_kernel_f16in(const __half* A, const BlockQ6_K* B, float* C,
- int M, int K, int N, int blocksPerRow) {
- const int row = blockIdx.y;
- const int warp = threadIdx.y;
- const int lane = threadIdx.x;
- const int col = blockIdx.x * 8 + warp;
- if (row >= M) return;
- const bool colIn = (col < N);
- float sum = 0.0f;
- __shared__ float a_sh[256];
- __shared__ float d_sh[8];
- __shared__ signed char scales_sh[8][16];
- __shared__ unsigned char ql_sh[8][128];
- __shared__ unsigned char qh_sh[8][64];
- for (int blk = 0; blk < blocksPerRow; blk++) {
- const int tid = warp * 32 + lane;
- const __half* aRow = A + row * K + blk * 256;
- a_sh[tid] = __half2float(aRow[tid]);
- if (colIn) {
- const BlockQ6_K* b = &B[col * blocksPerRow + blk];
- if (lane == 0) {
- d_sh[warp] = fp16_to_fp32(b->d);
- }
- if (lane < 16) {
- scales_sh[warp][lane] = b->scales[lane];
- }
- qh_sh[warp][lane] = b->qh[lane];
- qh_sh[warp][lane + 32] = b->qh[lane + 32];
- ql_sh[warp][lane] = b->ql[lane];
- ql_sh[warp][lane + 32] = b->ql[lane + 32];
- ql_sh[warp][lane + 64] = b->ql[lane + 64];
- ql_sh[warp][lane + 96] = b->ql[lane + 96];
- }
- __syncthreads();
- if (colIn) {
- const float d = d_sh[warp];
- #pragma unroll
- for (int i = 0; i < 8; i++) {
- const int idx = lane + (i * 32);
- const int is = idx >> 5;
- const int iq = idx & 31;
- const int qlIdx = (is >> 2) * 64 + (is & 1) * 32 + iq;
- const int qhIdx = (is >> 2) * 32 + iq;
- const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
- const unsigned char ql = ql_sh[warp][qlIdx];
- const unsigned char qh = qh_sh[warp][qhIdx];
- const int shift_ql = ((is & 3) < 2) ? 0 : 4;
- const int shift_qh = (is & 3) * 2;
- int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
- q -= 32;
- sum += a_sh[idx] * (d * (float)scales_sh[warp][scIdx] * (float)q);
- }
- }
- __syncthreads();
- }
- for (int offset = 16; offset > 0; offset >>= 1) {
- sum += __shfl_down_sync(0xffffffff, sum, offset);
- }
- if (colIn && lane == 0) {
- C[row * N + col] = sum;
- }
- }
- int cuda_matmul_f16_q6k(const void* A, const void* B, float* C, int M, int K, int N) {
- int blocksPerRow = K / 256;
- dim3 threads(32, 8);
- dim3 blocks((N + 7) / 8, M);
- matmul_q6k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ6_K*)B, C, M, K, N, blocksPerRow);
- CHECK_CUDA(cudaGetLastError());
- return 0;
- }
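- // Illustrative end-to-end sketch for the fused quantized path (a minimal
- // example under assumed names: d_x holds FP32 activations, d_w holds Q4_K
- // weight blocks laid out row-major as [N][K/256], d_y receives the output;
- // all are device buffers):
- //
- //   // y = x @ W^T, x:[M,K] FP32, W:[N,K] quantized, y:[M,N] FP32.
- //   cuda_matmul_f32_q4k(d_x, d_w, d_y, M, K, N);
- //   float ms = 0.0f;
- //   cuda_matmul_f32_q4k_timed(d_x, d_w, d_y, M, K, N, &ms);  // same, with timing
- //   cudaDeviceSynchronize();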