cuda_dequant_q4k.cu 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. #include "cuda_common.cuh"
  2. // ============================================================
  3. // Q4_K Dequantization Kernel
  4. // 256 elements per block, 128 bytes of qs (2 elements per byte)
  5. // Scales/mins: eight 6-bit scales and eight 6-bit mins packed into scales[0..11]
  6. // ============================================================
  // Dequantizes Q4_K blocks to float, producing one output element per thread.
  //
  // Launch layout: 1D grid with one CUDA block per quantized block
  // (blockIdx.x selects the BlockQ4_K) and 256 threads per CUDA block
  // (threadIdx.x selects the element inside it). Threads that fall outside
  // [0, numBlocks) x [0, 256) return without touching memory. No shared memory
  // and no barriers are used.
  //
  // Parameters:
  //   blocks    - array of numBlocks quantized blocks (read only); must be
  //               dereferenceable from device code.
  //   out       - destination; element e of block i is written to out[i * 256 + e].
  //   numBlocks - number of entries in `blocks`.
  //
  // Per element: out = q * (d * sc[sub]) - dmin * m[sub], where sub = e / 32 is the
  // 32-element sub-block and q is a 4-bit value taken from qs. Each pair of
  // sub-blocks shares 32 bytes of qs: the even sub-block reads the low nibbles,
  // the odd sub-block reads the high nibbles.
  __global__ void dequant_q4k_kernel(const BlockQ4_K* blocks, float* out, int numBlocks) {
      const int blockIdx_q = blockIdx.x;
      const int elemIdx = threadIdx.x;  // expected range: 0-255

      // Guard both axes. The element guard keeps a launch with more than 256
      // threads per block from indexing past scales/qs and past this block's
      // 256-float output slice.
      if (blockIdx_q >= numBlocks || elemIdx >= 256) return;

      const BlockQ4_K* b = &blocks[blockIdx_q];
      const float d = fp16_to_fp32(b->d);
      const float dmin = fp16_to_fp32(b->dmin);

      // Which sub-block and position
      const int subBlock = elemIdx / 32;  // 0-7
      const int subPos = elemIdx % 32;    // 0-31

      // Unpack only the scale/min pair this thread needs instead of all eight.
      // subBlock depends only on threadIdx.x / 32, so with a 1D block all 32
      // consecutive threads of a sub-block take the same branch.
      unsigned char sc, m;
      if (subBlock < 4) {
          // Sub-blocks 0-3: low 6 bits of scales[0..3] (scale) and scales[4..7] (min).
          sc = b->scales[subBlock] & 63;
          m = b->scales[subBlock + 4] & 63;
      } else {
          // Sub-blocks 4-7: low 4 bits come from a nibble of scales[8..11], the
          // upper 2 bits from the top two bits of scales[0..3] (scale) or
          // scales[4..7] (min).
          sc = (b->scales[subBlock + 4] & 0xF) | ((b->scales[subBlock - 4] >> 6) << 4);
          m = (b->scales[subBlock + 4] >> 4) | ((b->scales[subBlock] >> 6) << 4);
      }

      // Two consecutive sub-blocks share the same 32 qs bytes.
      const int qsIdx = (subBlock / 2) * 32 + subPos;
      const unsigned char qs = b->qs[qsIdx];
      const int val = (subBlock % 2 == 0) ? (qs & 0xF) : (qs >> 4);

      const float scale = d * (float)sc;
      const float minVal = dmin * (float)m;

      // Form the flat output index in size_t: blockIdx_q * 256 overflows a
      // 32-bit int once numBlocks * 256 exceeds 2^31 - 1.
      const size_t outIdx = (size_t)blockIdx_q * 256 + (size_t)elemIdx;
      out[outIdx] = (float)val * scale - minVal;
  }
  // Host-side launcher for dequant_q4k_kernel.
  //
  // Parameters:
  //   blocks    - untyped pointer to numBlocks Q4_K blocks; reinterpreted as
  //               const BlockQ4_K* and handed to the kernel unchanged.
  //   out       - destination buffer handed to the kernel unchanged.
  //   numBlocks - number of quantized blocks; also used as the grid size.
  //
  // Launches one CUDA block of 256 threads per quantized block on the default
  // stream and then passes cudaGetLastError() to CHECK_CUDA. There is no
  // synchronization here, so the kernel may still be running when this returns.
  //
  // Returns 0 when numBlocks is 0 (nothing is launched) and 0 after the
  // CHECK_CUDA line completes. What happens on a reported error is decided by
  // CHECK_CUDA, which is defined outside this file.
  int cuda_dequant_q4k(const void* blocks, float* out, int numBlocks) {
      if (numBlocks != 0) {
          const dim3 grid(numBlocks);       // one CUDA block per Q4_K block
          const dim3 threadsPerBlock(256);  // one thread per dequantized element
          const BlockQ4_K* q4kBlocks = (const BlockQ4_K*)blocks;

          dequant_q4k_kernel<<<grid, threadsPerBlock>>>(q4kBlocks, out, numBlocks);
          CHECK_CUDA(cudaGetLastError());
      }
      return 0;
  }