cuda_dequant_q5k.cu 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. #include "cuda_common.cuh"
  2. __global__ void dequant_q5k_kernel(const BlockQ5_K* blocks, float* out, int numBlocks) {
  3. int blockIdx_q = blockIdx.x;
  4. int elemIdx = threadIdx.x;
  5. if (blockIdx_q >= numBlocks) return;
  6. const BlockQ5_K* b = &blocks[blockIdx_q];
  7. float d = fp16_to_fp32(b->d);
  8. float dmin = fp16_to_fp32(b->dmin);
  9. unsigned char sc[8], m[8];
  10. #pragma unroll
  11. for (int j = 0; j < 4; j++) {
  12. sc[j] = b->scales[j] & 63;
  13. m[j] = b->scales[j + 4] & 63;
  14. }
  15. #pragma unroll
  16. for (int j = 4; j < 8; j++) {
  17. sc[j] = (b->scales[j + 4] & 0xF) | ((b->scales[j - 4] >> 6) << 4);
  18. m[j] = (b->scales[j + 4] >> 4) | ((b->scales[j] >> 6) << 4);
  19. }
  20. int subBlock = elemIdx / 32;
  21. int subPos = elemIdx % 32;
  22. int chunk = elemIdx / 64;
  23. int posInChunk = elemIdx % 64;
  24. int qsIdx = chunk * 32 + (posInChunk & 31);
  25. unsigned char qs = b->qs[qsIdx];
  26. int val;
  27. unsigned char hb = b->qh[posInChunk & 31];
  28. if (posInChunk < 32) {
  29. val = (qs & 0xF);
  30. val += ((hb >> (2 * chunk)) & 1) << 4;
  31. } else {
  32. val = (qs >> 4);
  33. val += ((hb >> (2 * chunk + 1)) & 1) << 4;
  34. }
  35. float scale = d * (float)sc[subBlock];
  36. float minVal = dmin * (float)m[subBlock];
  37. out[blockIdx_q * 256 + elemIdx] = (float)val * scale - minVal;
  38. }
  39. int cuda_dequant_q5k(const void* blocks, float* out, int numBlocks) {
  40. if (numBlocks == 0) return 0;
  41. dequant_q5k_kernel<<<numBlocks, 256>>>((const BlockQ5_K*)blocks, out, numBlocks);
  42. CHECK_CUDA(cudaGetLastError());
  43. return 0;
  44. }