// cuda_matmul.cu

#include "cuda_common.cuh"
#include <cuda_fp16.h>

namespace {

// Simple tiled GEMM kernels for correctness-first dense matmul.
// These are used as the default dense GEMM path when CUTLASS is not built.
constexpr int TILE = 16;
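
// Each thread block computes one TILE x TILE tile of C: blockIdx.x indexes the
// column tile, blockIdx.y the row tile, and each thread owns a single C element,
// accumulating over K in TILE-wide steps staged through shared memory.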
__global__ void matmul_f32_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C,
                                  int M, int K, int N) {
  __shared__ float As[TILE][TILE];
  __shared__ float Bs[TILE][TILE];
  const int row = blockIdx.y * TILE + threadIdx.y;
  const int col = blockIdx.x * TILE + threadIdx.x;
  float acc = 0.0f;
  for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
    const int aCol = t * TILE + threadIdx.x;
    const int bRow = t * TILE + threadIdx.y;
    As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
    Bs[threadIdx.y][threadIdx.x] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0f;
    __syncthreads();
    #pragma unroll
    for (int i = 0; i < TILE; ++i) {
      acc += As[threadIdx.y][i] * Bs[i][threadIdx.x];
    }
    __syncthreads();
  }
  if (row < M && col < N) {
    C[row * N + col] = acc;
  }
}

// Computes C = A @ B^T where B is stored row-major [N, K].
__global__ void matmul_f32_nt_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C,
                                     int M, int K, int N) {
  __shared__ float As[TILE][TILE];
  __shared__ float Bs[TILE][TILE];
  const int row = blockIdx.y * TILE + threadIdx.y;
  const int col = blockIdx.x * TILE + threadIdx.x;  // maps to n
  float acc = 0.0f;
  for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
    const int aCol = t * TILE + threadIdx.x;
    const int bCol = t * TILE + threadIdx.y;  // k index for B row
    As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0f;
    Bs[threadIdx.y][threadIdx.x] = (col < N && bCol < K) ? B[col * K + bCol] : 0.0f;
    __syncthreads();
    #pragma unroll
    for (int i = 0; i < TILE; ++i) {
      acc += As[threadIdx.y][i] * Bs[i][threadIdx.x];
    }
    __syncthreads();
  }
  if (row < M && col < N) {
    C[row * N + col] = acc;
  }
}

// Computes C = A @ B^T where A and B are row-major IEEE-754 half precision
// (the host wrapper passes them as uint16 storage and reinterprets to __half).
__global__ void matmul_f16_nt_kernel(const __half* __restrict__ A, const __half* __restrict__ B, float* __restrict__ C,
                                     int M, int K, int N) {
  __shared__ __half As[TILE][TILE];
  __shared__ __half Bs[TILE][TILE];
  const int row = blockIdx.y * TILE + threadIdx.y;
  const int col = blockIdx.x * TILE + threadIdx.x;
  float acc = 0.0f;
  for (int t = 0; t < (K + TILE - 1) / TILE; ++t) {
    const int aCol = t * TILE + threadIdx.x;
    const int bCol = t * TILE + threadIdx.y;
    As[threadIdx.y][threadIdx.x] = (row < M && aCol < K) ? A[row * K + aCol] : __float2half(0.0f);
    Bs[threadIdx.y][threadIdx.x] = (col < N && bCol < K) ? B[col * K + bCol] : __float2half(0.0f);
    __syncthreads();
    #pragma unroll
    for (int i = 0; i < TILE; ++i) {
      acc += __half2float(As[threadIdx.y][i]) * __half2float(Bs[i][threadIdx.x]);
    }
    __syncthreads();
  }
  if (row < M && col < N) {
    C[row * N + col] = acc;
  }
}

} // namespace
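
// ============================================================
// Fused quantized matmul kernels
// All kernels below share one launch shape: threads = (32, 8), i.e. 8 warps per
// block. blockIdx.y selects one row of A (and C); each warp handles one of 8
// consecutive output columns, so blockIdx.x covers N in groups of 8. K is walked
// in 256-wide superblocks: the block cooperatively stages the current 256-float
// slice of A in shared memory, each warp dequantizes its column's quant block,
// and the 32 per-lane partial sums are folded with a __shfl_down_sync reduction
// before lane 0 writes C[row, col].
// ============================================================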
__global__ void matmul_q5k_kernel(float* A, const BlockQ5_K* B, float* C,
                                  int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  // row is uniform across the block, so an early return here is safe.
  if (row >= M) return;
  // col is warp-specific. Do NOT early-return on col>=N because we use __syncthreads().
  const bool colIn = (col < N);
  float sum = 0.0f;
  // Cache the A tile (256 floats) once per block so the 8 warps (8 columns) reuse it.
  __shared__ float a_sh[256];
  __shared__ unsigned char sc_sh[8][8];
  __shared__ unsigned char m_sh[8][8];
  __shared__ float ds_sh[8][8];
  __shared__ float dm_sh[8][8];
  __shared__ float d_sh[8];
  __shared__ float dmin_sh[8];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    // Cache A tile once per block (256 floats). Each thread loads one element.
    const int tid = warp * 32 + lane;
    const float* aRow = A + row * K + blk * 256;
    a_sh[tid] = aRow[tid];
    if (colIn) {
      const BlockQ5_K* b = &B[col * blocksPerRow + blk];
      if (lane < 8) {
        unsigned char sc;
        unsigned char mn;
        if (lane < 4) {
          sc = b->scales[lane] & 63;
          mn = b->scales[lane + 4] & 63;
        } else {
          const int j = lane;
          sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
          mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
        }
        sc_sh[warp][lane] = sc;
        m_sh[warp][lane] = mn;
      }
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
        dmin_sh[warp] = fp16_to_fp32(b->dmin);
      }
    }
    // Ensure all warps have finished loading this block's a_sh before use,
    // and that no warp overwrites a_sh while others are still reading it.
    __syncthreads();
    if (colIn) {
      const BlockQ5_K* b = &B[col * blocksPerRow + blk];
      // Precompute per-group multipliers once (one lane per group).
      if (lane < 8) {
        const float d = d_sh[warp];
        const float dmin = dmin_sh[warp];
        const unsigned char sc = sc_sh[warp][lane];
        const unsigned char mn = m_sh[warp][lane];
        ds_sh[warp][lane] = d * (float)sc;
        dm_sh[warp][lane] = dmin * (float)mn;
      }
      __syncwarp();
      const unsigned char hb = b->qh[lane];
      #pragma unroll
      for (int p = 0; p < 4; p++) {
        const unsigned char qs = b->qs[p * 32 + lane];
        int q0 = qs & 0xF;
        int q1 = qs >> 4;
        q0 += ((hb >> (2 * p)) & 1) << 4;
        q1 += ((hb >> (2 * p + 1)) & 1) << 4;
        const int idx0 = p * 64 + lane;
        const int idx1 = idx0 + 32;
        const int g0 = 2 * p;
        const int g1 = g0 + 1;
        const float ds0 = ds_sh[warp][g0];
        const float dm0 = dm_sh[warp][g0];
        const float ds1 = ds_sh[warp][g1];
        const float dm1 = dm_sh[warp][g1];
        sum += a_sh[idx0] * ((float)q0 * ds0 - dm0);
        sum += a_sh[idx1] * ((float)q1 * ds1 - dm1);
      }
    }
    __syncthreads();
  }
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}
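
// ============================================================
// Host-side launch wrappers.
// For the quantized *_qNk entry points: K must be a multiple of 256 (one quant
// superblock per 256 weights), and B must point to N * (K / 256) quant blocks
// stored contiguously per output column (indexed as B[col * blocksPerRow + blk]).
// All pointers are device pointers; launches are asynchronous, and CHECK_CUDA
// (from cuda_common.cuh) is assumed to surface any CUDA error it is handed.
//
// Hypothetical usage sketch (caller-side names are illustrative only):
//   // dA, dC are device float buffers; dB holds the packed Q5_K weight blocks
//   cuda_matmul_f32_q5k(dA, dB, dC, M, K, N);   // requires K % 256 == 0
//   CHECK_CUDA(cudaDeviceSynchronize());        // wait for the result if needed
// ============================================================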
int cuda_matmul_f32_q5k(float* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q5k_kernel<<<blocks, threads>>>(A, (const BlockQ5_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

int cuda_matmul_f32(float* A, float* B, float* C, int M, int K, int N) {
  if (M <= 0 || N <= 0 || K <= 0) return 0;
  dim3 threads(TILE, TILE);
  dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
  matmul_f32_kernel<<<blocks, threads>>>(A, B, C, M, K, N);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

int cuda_matmul_f32_nt(float* A, float* B, float* C, int M, int K, int N) {
  if (M <= 0 || N <= 0 || K <= 0) return 0;
  dim3 threads(TILE, TILE);
  dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
  matmul_f32_nt_kernel<<<blocks, threads>>>(A, B, C, M, K, N);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

int cuda_matmul_f16_nt(const unsigned short* A, const unsigned short* B, float* C, int M, int K, int N) {
  if (M <= 0 || N <= 0 || K <= 0) return 0;
  dim3 threads(TILE, TILE);
  dim3 blocks((N + TILE - 1) / TILE, (M + TILE - 1) / TILE);
  matmul_f16_nt_kernel<<<blocks, threads>>>(reinterpret_cast<const __half*>(A), reinterpret_cast<const __half*>(B), C, M, K, N);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

// ============================================================
// Fused Q8_K MatMul Kernel (tiled)
// C[m,n] = sum_k A[m,k] * dequant(B[n,k])
// Uses shared memory tiles to reduce global memory pressure.
// ============================================================
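// Q8_K dequantization as used below: each 256-weight block carries a single
// scale d (already a float here, so no fp16 conversion) plus 256 8-bit quants,
// and a weight is reconstructed as d * (float)qs[idx].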
__global__ void matmul_q8k_kernel(float* A, const BlockQ8_K* B, float* C,
                                  int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ float d_sh[8];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    const int tid = warp * 32 + lane;
    const float* aRow = A + row * K + blk * 256;
    a_sh[tid] = aRow[tid];
    if (colIn && lane == 0) {
      d_sh[warp] = B[col * blocksPerRow + blk].d;
    }
    __syncthreads();
    if (colIn) {
      const BlockQ8_K* b = &B[col * blocksPerRow + blk];
      const float d = d_sh[warp];
      // Each lane handles 8 weights in the 256-wide block.
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        const int idx = lane + (i * 32);  // 0..255
        const float w = d * (float)((int)b->qs[idx]);
        sum += a_sh[idx] * w;
      }
    }
    __syncthreads();
  }
  // Warp reduction
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f32_q8k(float* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q8k_kernel<<<blocks, threads>>>(A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}
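
// Timed variant: identical launch, but bracketed with cudaEvent_t records and a
// synchronize so the kernel time (in milliseconds) can be returned via *ms.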
int cuda_matmul_f32_q8k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms) {
  cudaEvent_t evStart;
  cudaEvent_t evStop;
  CHECK_CUDA(cudaEventCreate(&evStart));
  CHECK_CUDA(cudaEventCreate(&evStop));
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  CHECK_CUDA(cudaEventRecord(evStart));
  matmul_q8k_kernel<<<blocks, threads>>>(A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaEventRecord(evStop));
  CHECK_CUDA(cudaEventSynchronize(evStop));
  float elapsed = 0.0f;
  CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
  if (ms != NULL) {
    *ms = elapsed;
  }
  CHECK_CUDA(cudaEventDestroy(evStart));
  CHECK_CUDA(cudaEventDestroy(evStop));
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

// ============================================================
// FP16 Input Variants - halve global memory traffic for the activations.
// Input A is FP16, dequantized weights are computed in FP32,
// accumulation is in FP32, output is FP32.
// ============================================================
__global__ void matmul_q8k_kernel_f16in(const __half* A, const BlockQ8_K* B, float* C,
                                        int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ float d_sh[8];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    const int tid = warp * 32 + lane;
    const __half* aRow = A + row * K + blk * 256;
    // Load FP16, convert to FP32 in shared memory
    a_sh[tid] = __half2float(aRow[tid]);
    if (colIn && lane == 0) {
      d_sh[warp] = B[col * blocksPerRow + blk].d;
    }
    __syncthreads();
    if (colIn) {
      const BlockQ8_K* b = &B[col * blocksPerRow + blk];
      const float d = d_sh[warp];
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        const int idx = lane + (i * 32);
        const float w = d * (float)((int)b->qs[idx]);
        sum += a_sh[idx] * w;
      }
    }
    __syncthreads();
  }
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f16_q8k(const void* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q8k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ8_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

// ============================================================
// Fused Q4_K MatMul Kernel (tiled)
// Caches the A tile and per-group scale/min multipliers in shared memory.
// ============================================================
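// Q4_K as unpacked below: 256 4-bit quants per block in 8 groups of 32. Each
// group has a 6-bit scale and a 6-bit min packed into the 12-byte scales[]
// array (plain low-6-bits for groups 0-3; nibble plus spare high bits for
// groups 4-7), and a weight is d * scale * q - dmin * min.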
__global__ void matmul_q4k_kernel(float* A, const BlockQ4_K* B, float* C,
                                  int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ unsigned char sc_sh[8][8];
  __shared__ unsigned char m_sh[8][8];
  __shared__ float ds_sh[8][8];
  __shared__ float dm_sh[8][8];
  __shared__ float d_sh[8];
  __shared__ float dmin_sh[8];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    const int tid = warp * 32 + lane;
    const float* aRow = A + row * K + blk * 256;
    a_sh[tid] = aRow[tid];
    if (colIn) {
      const BlockQ4_K* b = &B[col * blocksPerRow + blk];
      // Parallel unpack scale/min for groups 0..7 (one lane per group).
      if (lane < 8) {
        unsigned char sc;
        unsigned char mn;
        if (lane < 4) {
          sc = b->scales[lane] & 63;
          mn = b->scales[lane + 4] & 63;
        } else {
          const int j = lane;
          sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
          mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
        }
        sc_sh[warp][lane] = sc;
        m_sh[warp][lane] = mn;
      }
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
        dmin_sh[warp] = fp16_to_fp32(b->dmin);
      }
    }
    __syncthreads();
    if (colIn) {
      const BlockQ4_K* b = &B[col * blocksPerRow + blk];
      // Precompute per-group float multipliers once.
      if (lane < 8) {
        const float d = d_sh[warp];
        const float dmin = dmin_sh[warp];
        const unsigned char sc = sc_sh[warp][lane];
        const unsigned char mn = m_sh[warp][lane];
        ds_sh[warp][lane] = d * (float)sc;
        dm_sh[warp][lane] = dmin * (float)mn;
      }
      __syncwarp();
      const float ds0 = ds_sh[warp][0];
      const float dm0 = dm_sh[warp][0];
      const float ds1 = ds_sh[warp][1];
      const float dm1 = dm_sh[warp][1];
      const float ds2 = ds_sh[warp][2];
      const float dm2 = dm_sh[warp][2];
      const float ds3 = ds_sh[warp][3];
      const float dm3 = dm_sh[warp][3];
      const float ds4 = ds_sh[warp][4];
      const float dm4 = dm_sh[warp][4];
      const float ds5 = ds_sh[warp][5];
      const float dm5 = dm_sh[warp][5];
      const float ds6 = ds_sh[warp][6];
      const float dm6 = dm_sh[warp][6];
      const float ds7 = ds_sh[warp][7];
      const float dm7 = dm_sh[warp][7];
      // Each lane processes 4 bytes; each byte contains 2 nibbles => 8 values per lane.
      // This halves qs loads and reduces bit ops.
      #pragma unroll
      for (int p = 0; p < 4; p++) {
        const unsigned char qs = b->qs[p * 32 + lane];
        const int q0 = qs & 0xF;
        const int q1 = qs >> 4;
        const int idx0 = p * 64 + lane;  // group = 2*p
        const int idx1 = idx0 + 32;      // group = 2*p + 1
        float dsA, dmA, dsB, dmB;
        if (p == 0) {
          dsA = ds0; dmA = dm0; dsB = ds1; dmB = dm1;
        } else if (p == 1) {
          dsA = ds2; dmA = dm2; dsB = ds3; dmB = dm3;
        } else if (p == 2) {
          dsA = ds4; dmA = dm4; dsB = ds5; dmB = dm5;
        } else {
          dsA = ds6; dmA = dm6; dsB = ds7; dmB = dm7;
        }
        const float w0 = (float)q0 * dsA - dmA;
        const float w1 = (float)q1 * dsB - dmB;
        sum += a_sh[idx0] * w0;
        sum += a_sh[idx1] * w1;
      }
    }
    __syncthreads();
  }
  // Warp reduction
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f32_q4k(float* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q4k_kernel<<<blocks, threads>>>(A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

int cuda_matmul_f32_q4k_timed(float* A, const void* B, float* C, int M, int K, int N, float* ms) {
  cudaEvent_t evStart;
  cudaEvent_t evStop;
  CHECK_CUDA(cudaEventCreate(&evStart));
  CHECK_CUDA(cudaEventCreate(&evStop));
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  CHECK_CUDA(cudaEventRecord(evStart));
  matmul_q4k_kernel<<<blocks, threads>>>(A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaEventRecord(evStop));
  CHECK_CUDA(cudaEventSynchronize(evStop));
  float elapsed = 0.0f;
  CHECK_CUDA(cudaEventElapsedTime(&elapsed, evStart, evStop));
  if (ms != NULL) {
    *ms = elapsed;
  }
  CHECK_CUDA(cudaEventDestroy(evStart));
  CHECK_CUDA(cudaEventDestroy(evStop));
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

// ============================================================
// Fused Q2_K MatMul Kernel (tiled)
// ============================================================
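// Q2_K as unpacked below: 256 2-bit quants packed four per byte into qs[64].
// Each of the 16 scale bytes covers 16 weights: the low nibble scales the quant
// (times d) and the high nibble scales the subtracted minimum (times dmin), so
// a weight is d * (sc & 0xF) * q - dmin * (sc >> 4).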
__global__ void matmul_q2k_kernel(float* A, const BlockQ2_K* B, float* C,
                                  int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ float d_sh[8];
  __shared__ float dmin_sh[8];
  __shared__ unsigned char scales_sh[8][16];
  __shared__ unsigned char qs_sh[8][64];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    // Cache A tile once per block (256 floats) to avoid redundant global loads.
    // Each thread loads one element: tid in [0,255].
    const int tid = warp * 32 + lane;
    const float* aRow = A + row * K + blk * 256;
    a_sh[tid] = aRow[tid];
    if (colIn) {
      const BlockQ2_K* b = &B[col * blocksPerRow + blk];
      // Cooperative per-warp cache.
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
        dmin_sh[warp] = fp16_to_fp32(b->dmin);
      }
      if (lane < 16) {
        scales_sh[warp][lane] = b->scales[lane];
      }
      // Load 64 bytes qs with 32 lanes.
      qs_sh[warp][lane] = b->qs[lane];
      qs_sh[warp][lane + 32] = b->qs[lane + 32];
    }
    __syncthreads();
    if (colIn) {
      const float d = d_sh[warp];
      const float dmin = dmin_sh[warp];
      // Each lane handles 8 values.
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        const int idx = lane + (i * 32);  // 0..255
        const int is = idx >> 5;          // 0..7
        const int iq = idx & 31;          // 0..31
        const int qsIdx = (is >> 2) * 32 + iq;
        const int shift = (is & 3) * 2;
        const int val = (qs_sh[warp][qsIdx] >> shift) & 3;
        const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
        const unsigned char sc = scales_sh[warp][scIdx];
        const float dl = d * (float)(sc & 0xF);
        const float ml = dmin * (float)(sc >> 4);
        const float w = dl * (float)val - ml;
        sum += a_sh[idx] * w;
      }
    }
    // Ensure all warps finished reading this block's a_sh before it is overwritten.
    __syncthreads();
  }
  // Warp reduction
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f32_q2k(float* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q2k_kernel<<<blocks, threads>>>(A, (const BlockQ2_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

// ============================================================
// Fused Q3_K MatMul Kernel
// ============================================================
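// Q3_K as unpacked below: the low 2 bits of each quant sit four per byte in
// qs[64] and the third (high) bit in the 32-byte hmask; a cleared mask bit
// subtracts 4, giving quants in -4..3. The 6-bit scales are packed into 12
// bytes and re-centered by -32 before use.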
__global__ void matmul_q3k_kernel(float* A, const BlockQ3_K* B, float* C,
                                  int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ float d_sh[8];
  __shared__ unsigned char scales_sh[8][12];
  __shared__ unsigned char qs_sh[8][64];
  __shared__ unsigned char hmask_sh[8][32];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    // Cache A tile once per block (256 floats). Each thread loads one element.
    const int tid = warp * 32 + lane;
    const float* aRow = A + row * K + blk * 256;
    a_sh[tid] = aRow[tid];
    if (colIn) {
      const BlockQ3_K* b = &B[col * blocksPerRow + blk];
      // Cache quant block bytes.
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
      }
      if (lane < 12) {
        scales_sh[warp][lane] = b->scales[lane];
      }
      // qs: 64 bytes
      qs_sh[warp][lane] = b->qs[lane];
      qs_sh[warp][lane + 32] = b->qs[lane + 32];
      // hmask: 32 bytes
      hmask_sh[warp][lane] = b->hmask[lane];
    }
    __syncthreads();
    if (colIn) {
      const float d = d_sh[warp];
      // Each lane handles 8 elements.
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        const int idx = lane + (i * 32);  // 0..255
        const int is = idx >> 5;          // 0..7
        const int iq = idx & 31;          // 0..31
        const int qsIdx = (is >> 2) * 32 + iq;
        const int shift = (is & 3) * 2;
        int qv = (qs_sh[warp][qsIdx] >> shift) & 0x3;
        const unsigned char m = (unsigned char)(1 << ((is >> 2) * 4 + (is & 3)));
        if ((hmask_sh[warp][iq] & m) == 0) {
          qv -= 4;
        }
        const int sIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);  // 0..15
        unsigned char sc;
        if (sIdx < 8) {
          sc = scales_sh[warp][sIdx] & 0xF;
        } else {
          sc = scales_sh[warp][sIdx - 8] >> 4;
        }
        sc |= ((scales_sh[warp][8 + (sIdx & 3)] >> (2 * (sIdx >> 2))) & 0x3) << 4;
        const float scale = (float)((int)((signed char)sc) - 32);
        const float w = d * scale * (float)qv;
        sum += a_sh[idx] * w;
      }
    }
    __syncthreads();
  }
  // Warp reduction
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f32_q3k(float* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q3k_kernel<<<blocks, threads>>>(A, (const BlockQ3_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

// ============================================================
// Fused Q6_K MatMul Kernel
// ============================================================
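// Q6_K as unpacked below: each quant is 6 bits, split into a low nibble in
// ql[128] and 2 high bits in qh[64], then re-centered by -32. scales[16] are
// signed 8-bit values, one per 16-weight group, so a weight is
// d * scales[group] * (q - 32).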
__global__ void matmul_q6k_kernel(float* A, const BlockQ6_K* B, float* C,
                                  int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ float d_sh[8];
  __shared__ signed char scales_sh[8][16];
  __shared__ unsigned char ql_sh[8][128];
  __shared__ unsigned char qh_sh[8][64];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    // Cache A tile once per block.
    const int tid = warp * 32 + lane;
    const float* aRow = A + row * K + blk * 256;
    a_sh[tid] = aRow[tid];
    if (colIn) {
      const BlockQ6_K* b = &B[col * blocksPerRow + blk];
      // Cache quant block bytes.
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
      }
      if (lane < 16) {
        scales_sh[warp][lane] = b->scales[lane];
      }
      // qh: 64 bytes
      qh_sh[warp][lane] = b->qh[lane];
      qh_sh[warp][lane + 32] = b->qh[lane + 32];
      // ql: 128 bytes
      ql_sh[warp][lane] = b->ql[lane];
      ql_sh[warp][lane + 32] = b->ql[lane + 32];
      ql_sh[warp][lane + 64] = b->ql[lane + 64];
      ql_sh[warp][lane + 96] = b->ql[lane + 96];
    }
    __syncthreads();
    if (colIn) {
      const float d = d_sh[warp];
      // Each lane handles 8 elements.
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        const int idx = lane + (i * 32);  // 0..255
        const int is = idx >> 5;          // 0..7
        const int iq = idx & 31;          // 0..31
        const int qlIdx = (is >> 2) * 64 + (is & 1) * 32 + iq;
        const int qhIdx = (is >> 2) * 32 + iq;
        const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
        const unsigned char ql = ql_sh[warp][qlIdx];
        const unsigned char qh = qh_sh[warp][qhIdx];
        const int shift_ql = ((is & 3) < 2) ? 0 : 4;
        const int shift_qh = (is & 3) * 2;
        int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
        q -= 32;
        const float w = d * (float)scales_sh[warp][scIdx] * (float)q;
        sum += a_sh[idx] * w;
      }
    }
    __syncthreads();
  }
  // Warp reduction
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f32_q6k(float* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q6k_kernel<<<blocks, threads>>>(A, (const BlockQ6_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

// ============================================================
// FP16 Input Variants for Q4K, Q5K, Q2K, Q3K, Q6K
// Same logic as FP32 versions but load A as FP16
// ============================================================
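// These kernels are line-for-line copies of the FP32-input versions above; the
// only change is that the shared A tile is filled through __half2float(), so
// the dequantization math and the accumulator stay in FP32.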
__global__ void matmul_q4k_kernel_f16in(const __half* A, const BlockQ4_K* B, float* C,
                                        int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ unsigned char sc_sh[8][8];
  __shared__ unsigned char m_sh[8][8];
  __shared__ float ds_sh[8][8];
  __shared__ float dm_sh[8][8];
  __shared__ float d_sh[8];
  __shared__ float dmin_sh[8];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    const int tid = warp * 32 + lane;
    const __half* aRow = A + row * K + blk * 256;
    a_sh[tid] = __half2float(aRow[tid]);
    if (colIn) {
      const BlockQ4_K* b = &B[col * blocksPerRow + blk];
      if (lane < 8) {
        unsigned char sc, mn;
        if (lane < 4) {
          sc = b->scales[lane] & 63;
          mn = b->scales[lane + 4] & 63;
        } else {
          const int j = lane;
          sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
          mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
        }
        sc_sh[warp][lane] = sc;
        m_sh[warp][lane] = mn;
      }
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
        dmin_sh[warp] = fp16_to_fp32(b->dmin);
      }
    }
    __syncthreads();
    if (colIn) {
      const BlockQ4_K* b = &B[col * blocksPerRow + blk];
      if (lane < 8) {
        ds_sh[warp][lane] = d_sh[warp] * (float)sc_sh[warp][lane];
        dm_sh[warp][lane] = dmin_sh[warp] * (float)m_sh[warp][lane];
      }
      __syncwarp();
      #pragma unroll
      for (int p = 0; p < 4; p++) {
        const unsigned char qs = b->qs[p * 32 + lane];
        const int q0 = qs & 0xF;
        const int q1 = qs >> 4;
        const int idx0 = p * 64 + lane;
        const int idx1 = idx0 + 32;
        const int g0 = 2 * p;
        const int g1 = g0 + 1;
        sum += a_sh[idx0] * ((float)q0 * ds_sh[warp][g0] - dm_sh[warp][g0]);
        sum += a_sh[idx1] * ((float)q1 * ds_sh[warp][g1] - dm_sh[warp][g1]);
      }
    }
    __syncthreads();
  }
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f16_q4k(const void* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q4k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ4_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

__global__ void matmul_q5k_kernel_f16in(const __half* A, const BlockQ5_K* B, float* C,
                                        int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ unsigned char sc_sh[8][8];
  __shared__ unsigned char m_sh[8][8];
  __shared__ float ds_sh[8][8];
  __shared__ float dm_sh[8][8];
  __shared__ float d_sh[8];
  __shared__ float dmin_sh[8];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    const int tid = warp * 32 + lane;
    const __half* aRow = A + row * K + blk * 256;
    a_sh[tid] = __half2float(aRow[tid]);
    if (colIn) {
      const BlockQ5_K* b = &B[col * blocksPerRow + blk];
      if (lane < 8) {
        unsigned char sc, mn;
        if (lane < 4) {
          sc = b->scales[lane] & 63;
          mn = b->scales[lane + 4] & 63;
        } else {
          const int j = lane;
          sc = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
          mn = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
        }
        sc_sh[warp][lane] = sc;
        m_sh[warp][lane] = mn;
      }
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
        dmin_sh[warp] = fp16_to_fp32(b->dmin);
      }
    }
    __syncthreads();
    if (colIn) {
      const BlockQ5_K* b = &B[col * blocksPerRow + blk];
      if (lane < 8) {
        ds_sh[warp][lane] = d_sh[warp] * (float)sc_sh[warp][lane];
        dm_sh[warp][lane] = dmin_sh[warp] * (float)m_sh[warp][lane];
      }
      __syncwarp();
      const unsigned char hb = b->qh[lane];
      #pragma unroll
      for (int p = 0; p < 4; p++) {
        const unsigned char qs = b->qs[p * 32 + lane];
        int q0 = qs & 0xF;
        int q1 = qs >> 4;
        q0 += ((hb >> (2 * p)) & 1) << 4;
        q1 += ((hb >> (2 * p + 1)) & 1) << 4;
        const int idx0 = p * 64 + lane;
        const int idx1 = idx0 + 32;
        const int g0 = 2 * p;
        const int g1 = g0 + 1;
        sum += a_sh[idx0] * ((float)q0 * ds_sh[warp][g0] - dm_sh[warp][g0]);
        sum += a_sh[idx1] * ((float)q1 * ds_sh[warp][g1] - dm_sh[warp][g1]);
      }
    }
    __syncthreads();
  }
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f16_q5k(const void* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q5k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ5_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

__global__ void matmul_q2k_kernel_f16in(const __half* A, const BlockQ2_K* B, float* C,
                                        int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ float d_sh[8];
  __shared__ float dmin_sh[8];
  __shared__ unsigned char scales_sh[8][16];
  __shared__ unsigned char qs_sh[8][64];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    const int tid = warp * 32 + lane;
    const __half* aRow = A + row * K + blk * 256;
    a_sh[tid] = __half2float(aRow[tid]);
    if (colIn) {
      const BlockQ2_K* b = &B[col * blocksPerRow + blk];
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
        dmin_sh[warp] = fp16_to_fp32(b->dmin);
      }
      if (lane < 16) {
        scales_sh[warp][lane] = b->scales[lane];
      }
      qs_sh[warp][lane] = b->qs[lane];
      qs_sh[warp][lane + 32] = b->qs[lane + 32];
    }
    __syncthreads();
    if (colIn) {
      const float d = d_sh[warp];
      const float dmin = dmin_sh[warp];
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        const int idx = lane + (i * 32);
        const int is = idx >> 5;
        const int iq = idx & 31;
        const int qsIdx = (is >> 2) * 32 + iq;
        const int shift = (is & 3) * 2;
        const int val = (qs_sh[warp][qsIdx] >> shift) & 3;
        const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
        const unsigned char sc = scales_sh[warp][scIdx];
        const float dl = d * (float)(sc & 0xF);
        const float ml = dmin * (float)(sc >> 4);
        sum += a_sh[idx] * (dl * (float)val - ml);
      }
    }
    __syncthreads();
  }
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f16_q2k(const void* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q2k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ2_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

__global__ void matmul_q3k_kernel_f16in(const __half* A, const BlockQ3_K* B, float* C,
                                        int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ float d_sh[8];
  __shared__ unsigned char scales_sh[8][12];
  __shared__ unsigned char qs_sh[8][64];
  __shared__ unsigned char hmask_sh[8][32];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    const int tid = warp * 32 + lane;
    const __half* aRow = A + row * K + blk * 256;
    a_sh[tid] = __half2float(aRow[tid]);
    if (colIn) {
      const BlockQ3_K* b = &B[col * blocksPerRow + blk];
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
      }
      if (lane < 12) {
        scales_sh[warp][lane] = b->scales[lane];
      }
      qs_sh[warp][lane] = b->qs[lane];
      qs_sh[warp][lane + 32] = b->qs[lane + 32];
      hmask_sh[warp][lane] = b->hmask[lane];
    }
    __syncthreads();
    if (colIn) {
      const float d = d_sh[warp];
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        const int idx = lane + (i * 32);
        const int is = idx >> 5;
        const int iq = idx & 31;
        const int qsIdx = (is >> 2) * 32 + iq;
        const int shift = (is & 3) * 2;
        int qv = (qs_sh[warp][qsIdx] >> shift) & 0x3;
        const unsigned char m = (unsigned char)(1 << ((is >> 2) * 4 + (is & 3)));
        if ((hmask_sh[warp][iq] & m) == 0) {
          qv -= 4;
        }
        const int sIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
        unsigned char sc;
        if (sIdx < 8) {
          sc = scales_sh[warp][sIdx] & 0xF;
        } else {
          sc = scales_sh[warp][sIdx - 8] >> 4;
        }
        sc |= ((scales_sh[warp][8 + (sIdx & 3)] >> (2 * (sIdx >> 2))) & 0x3) << 4;
        const float scale = (float)((int)((signed char)sc) - 32);
        sum += a_sh[idx] * (d * scale * (float)qv);
      }
    }
    __syncthreads();
  }
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f16_q3k(const void* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q3k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ3_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}

__global__ void matmul_q6k_kernel_f16in(const __half* A, const BlockQ6_K* B, float* C,
                                        int M, int K, int N, int blocksPerRow) {
  const int row = blockIdx.y;
  const int warp = threadIdx.y;
  const int lane = threadIdx.x;
  const int col = blockIdx.x * 8 + warp;
  if (row >= M) return;
  const bool colIn = (col < N);
  float sum = 0.0f;
  __shared__ float a_sh[256];
  __shared__ float d_sh[8];
  __shared__ signed char scales_sh[8][16];
  __shared__ unsigned char ql_sh[8][128];
  __shared__ unsigned char qh_sh[8][64];
  for (int blk = 0; blk < blocksPerRow; blk++) {
    const int tid = warp * 32 + lane;
    const __half* aRow = A + row * K + blk * 256;
    a_sh[tid] = __half2float(aRow[tid]);
    if (colIn) {
      const BlockQ6_K* b = &B[col * blocksPerRow + blk];
      if (lane == 0) {
        d_sh[warp] = fp16_to_fp32(b->d);
      }
      if (lane < 16) {
        scales_sh[warp][lane] = b->scales[lane];
      }
      qh_sh[warp][lane] = b->qh[lane];
      qh_sh[warp][lane + 32] = b->qh[lane + 32];
      ql_sh[warp][lane] = b->ql[lane];
      ql_sh[warp][lane + 32] = b->ql[lane + 32];
      ql_sh[warp][lane + 64] = b->ql[lane + 64];
      ql_sh[warp][lane + 96] = b->ql[lane + 96];
    }
    __syncthreads();
    if (colIn) {
      const float d = d_sh[warp];
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        const int idx = lane + (i * 32);
        const int is = idx >> 5;
        const int iq = idx & 31;
        const int qlIdx = (is >> 2) * 64 + (is & 1) * 32 + iq;
        const int qhIdx = (is >> 2) * 32 + iq;
        const int scIdx = (is >> 2) * 8 + (is & 3) * 2 + (iq >> 4);
        const unsigned char ql = ql_sh[warp][qlIdx];
        const unsigned char qh = qh_sh[warp][qhIdx];
        const int shift_ql = ((is & 3) < 2) ? 0 : 4;
        const int shift_qh = (is & 3) * 2;
        int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
        q -= 32;
        sum += a_sh[idx] * (d * (float)scales_sh[warp][scIdx] * (float)q);
      }
    }
    __syncthreads();
  }
  for (int offset = 16; offset > 0; offset >>= 1) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  if (colIn && lane == 0) {
    C[row * N + col] = sum;
  }
}

int cuda_matmul_f16_q6k(const void* A, const void* B, float* C, int M, int K, int N) {
  int blocksPerRow = K / 256;
  dim3 threads(32, 8);
  dim3 blocks((N + 7) / 8, M);
  matmul_q6k_kernel_f16in<<<blocks, threads>>>((const __half*)A, (const BlockQ6_K*)B, C, M, K, N, blocksPerRow);
  CHECK_CUDA(cudaGetLastError());
  return 0;
}