| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- #include "cuda_common.cuh"
- // ============================================================
- // Q4_K Dequantization Kernel
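- // 256 elements per super-block (8 sub-blocks of 32); 128 bytes of qs (two 4-bit values per byte)
- // Each sub-block has a 6-bit scale and a 6-bit min; all 16 are bit-packed into 12 bytes (scales[])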
- // ============================================================
// Expands Q4_K super-blocks into floats.
// Expected launch: gridDim.x >= numBlocks, blockDim.x == 256, no shared memory.
// CUDA block `blockIdx.x` reads blocks[blockIdx.x] and writes the 256 floats
// out[blockIdx.x * 256 .. blockIdx.x * 256 + 255]; thread t writes element t.
//
// Element t belongs to sub-block t / 32. Sub-blocks 2k and 2k+1 share the qs bytes
// qs[32k .. 32k+31]: the even sub-block takes the low nibble, the odd one the high
// nibble. Result: out[t] = (d * scale[sb]) * q - dmin * min[sb].
// NOTE(review): the scale/min bit layout appears to follow ggml's
// get_scale_min_k4 / dequantize_row_q4_K — confirm against the reference.
__global__ void dequant_q4k_kernel(const BlockQ4_K* __restrict__ blocks,
                                   float* __restrict__ out, int numBlocks) {
    const int blockIdx_q = blockIdx.x;  // super-block handled by this CUDA block
    const int elemIdx = threadIdx.x;    // element within the super-block, 0-255

    // Skip the grid tail, and (defensively) any threads past 256 should the kernel
    // ever be launched with a larger block.
    if (blockIdx_q >= numBlocks || elemIdx >= 256) return;

    const BlockQ4_K* b = &blocks[blockIdx_q];
    const float d = fp16_to_fp32(b->d);
    const float dmin = fp16_to_fp32(b->dmin);

    // Which sub-block and position. subBlock is constant across each warp
    // (threads 32k..32k+31), so the branch below does not diverge.
    const int subBlock = elemIdx >> 5;  // elemIdx / 32, 0-7
    const int subPos = elemIdx & 31;    // elemIdx % 32

    // Decode only this sub-block's 6-bit scale/min. Computing a single pair keeps
    // sc/m in registers; a runtime-indexed array of all eight would be placed in
    // local memory.
    unsigned char sc, m;
    if (subBlock < 4) {
        // Sub-blocks 0-3: low 6 bits of bytes 0-3 (scale) and 4-7 (min).
        sc = b->scales[subBlock] & 63;
        m = b->scales[subBlock + 4] & 63;
    } else {
        // Sub-blocks 4-7: low 4 bits from a nibble of bytes 8-11, top 2 bits from
        // the high bits of bytes 0-3 (scale) or 4-7 (min).
        sc = (b->scales[subBlock + 4] & 0xF) | ((b->scales[subBlock - 4] >> 6) << 4);
        m = (b->scales[subBlock + 4] >> 4) | ((b->scales[subBlock] >> 6) << 4);
    }

    // Sub-blocks 2k (low nibble) and 2k+1 (high nibble) share qs[32k .. 32k+31].
    const unsigned char qs = b->qs[(subBlock >> 1) * 32 + subPos];
    const int val = (subBlock & 1) ? (qs >> 4) : (qs & 0xF);

    const float scale = d * (float)sc;
    const float minVal = dmin * (float)m;

    // 64-bit index: blockIdx_q * 256 in int overflows once numBlocks exceeds 2^23.
    const size_t outIdx = (size_t)blockIdx_q * 256 + (size_t)elemIdx;
    out[outIdx] = (float)val * scale - minVal;
}
// Host launcher for dequant_q4k_kernel.
// Launches one 256-thread CUDA block per Q4_K super-block on the default stream,
// so `out` must have room for numBlocks * 256 floats.
// NOTE(review): `blocks` and `out` are presumably device pointers — confirm at callers.
// The launch is asynchronous. Only launch/configuration errors are checked here, via
// CHECK_CUDA(cudaGetLastError()); faults inside the kernel surface at a later
// synchronizing call.
// Returns 0 when nothing needs launching (numBlocks == 0) and after a launch;
// error reporting is whatever CHECK_CUDA does.
int cuda_dequant_q4k(const void* blocks, float* out, int numBlocks) {
    if (numBlocks != 0) {
        const BlockQ4_K* superBlocks = static_cast<const BlockQ4_K*>(blocks);
        dequant_q4k_kernel<<<numBlocks, 256>>>(superBlocks, out, numBlocks);
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}
|