cuda_dequant_other.cu 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. #include "cuda_common.cuh"
  2. // ============================================================
  3. // Q6_K Dequantization Kernel
  4. // 256 elements, lower 4 bits + upper 2 bits
  5. // ============================================================
  // Dequantizes Q6_K super-blocks (256 elements each) to float32.
  //
  // Launch layout (as used by cuda_dequant_q6k): one CUDA block per super-block
  // (gridDim.x >= numBlocks, extra blocks exit early) and exactly 256 threads
  // per block. Thread t writes element t of its super-block, so `out` must hold
  // at least numBlocks * 256 floats.
  //
  // Element e is decoded with h = e / 128 (128-element half), g = (e % 128) / 32
  // (32-element group inside the half) and l = e % 32 (lane inside the group):
  //   low 4 bits : ql[h*64 + (g%2)*32 + l], low nibble for g < 2, high nibble otherwise
  //   high 2 bits: qh[h*32 + l], bit pair (2g, 2g+1)
  //   scale      : scales[h*8 + 2g + l/16]
  //   value      : d * scale * (q - 32)
  __global__ void dequant_q6k_kernel(const BlockQ6_K* blocks, float* out, int numBlocks) {
    const int blockIdx_q = blockIdx.x;
    if (blockIdx_q >= numBlocks) return;
    const int elemIdx = threadIdx.x;  // 0..255 (launched with 256 threads)

    const BlockQ6_K* b = &blocks[blockIdx_q];
    const float d = fp16_to_fp32(b->d);

    const int is = elemIdx / 32;    // 32-element group within the super-block, 0..7
    const int iq = elemIdx % 32;    // lane within that group
    const int halfIdx = is / 4;     // which 128-element half
    const int grp = is % 4;         // group within the half

    const int qlIdx = halfIdx * 64 + (grp % 2) * 32 + iq;
    const int qhIdx = halfIdx * 32 + iq;
    const int scIdx = halfIdx * 8 + grp * 2 + iq / 16;

    const unsigned char ql = b->ql[qlIdx];
    const unsigned char qh = b->qh[qhIdx];
    const int shift_ql = (grp < 2) ? 0 : 4;
    const int shift_qh = grp * 2;
    int q = ((ql >> shift_ql) & 0xF) | (((qh >> shift_qh) & 3) << 4);
    q -= 32;  // stored values are biased by 32

    // ggml's block_q6_K stores the per-group scales as int8_t. The explicit
    // signed cast keeps negative scales negative even if BlockQ6_K declares
    // the field as an unsigned byte (no-op when it is already signed).
    const float scale = (float)(signed char)b->scales[scIdx];

    // 64-bit offset: blockIdx_q * 256 in int overflows past 2^31 output elements.
    const size_t outIdx = (size_t)blockIdx_q * 256 + (size_t)elemIdx;
    out[outIdx] = d * scale * (float)q;
  }
  // Host launcher for dequant_q6k_kernel.
  // Launches one 256-thread CUDA block per Q6_K super-block on the default
  // stream (no stream argument is passed) and does not synchronize. Only
  // launch-time errors are checked, via CHECK_CUDA(cudaGetLastError()); what
  // happens on failure is defined by CHECK_CUDA in cuda_common.cuh.
  // A count of zero is a no-op. Returns 0 on the normal path.
  int cuda_dequant_q6k(const void* blocks, float* out, int numBlocks) {
    if (numBlocks != 0) {
      const BlockQ6_K* typedBlocks = static_cast<const BlockQ6_K*>(blocks);
      dequant_q6k_kernel<<<numBlocks, 256>>>(typedBlocks, out, numBlocks);
      CHECK_CUDA(cudaGetLastError());
    }
    return 0;
  }
  34. // ============================================================
  35. // Q3_K Dequantization Kernel
  36. // ============================================================
  // Decodes the idx-th (0..15) 6-bit scale from the 12-byte packed Q3_K scale
  // array and re-centers it to the range [-32, 31].
  //   low 4 bits : bytes 0..7 — low nibble for idx < 8, high nibble of
  //                byte idx-8 for idx >= 8
  //   high 2 bits: byte 8 + idx%4, bit pair selected by idx/4
  __device__ __forceinline__ signed char unpack_q3_scale(const unsigned char* packed, int idx) {
    const bool lowHalf = idx < 8;
    const unsigned char src = packed[lowHalf ? idx : idx - 8];
    const unsigned char lo4 = lowHalf ? (unsigned char)(src & 0xF)
                                      : (unsigned char)(src >> 4);
    const unsigned char hi2 = (packed[8 + (idx % 4)] >> (2 * (idx / 4))) & 0x3;
    const unsigned char raw = lo4 | (hi2 << 4);
    return (signed char)raw - 32;
  }
  // Dequantizes Q3_K super-blocks (256 elements each) to float32.
  //
  // Launch layout (as used by cuda_dequant_q3k): one CUDA block per super-block
  // (gridDim.x >= numBlocks, extra blocks exit early) and exactly 256 threads
  // per block. Thread t writes element t, so `out` must hold at least
  // numBlocks * 256 floats.
  //
  // Element e is decoded with h = e / 128, g = (e % 128) / 32, l = e % 32:
  //   low 2 bits : qs[h*32 + l], bit pair (2g, 2g+1)
  //   high bit   : hmask[l], bit (4h + g); when that bit is CLEAR, 4 is subtracted
  //   scale      : 6-bit scale #(h*8 + 2g + l/16) from unpack_q3_scale (already -32)
  //   value      : d * scale * q
  __global__ void dequant_q3k_kernel(const BlockQ3_K* blocks, float* out, int numBlocks) {
    const int blockIdx_q = blockIdx.x;
    if (blockIdx_q >= numBlocks) return;
    const int elemIdx = threadIdx.x;  // 0..255 (launched with 256 threads)

    const BlockQ3_K* b = &blocks[blockIdx_q];
    const float d = fp16_to_fp32(b->d);

    const int is = elemIdx / 32;    // 32-element group within the super-block, 0..7
    const int iq = elemIdx % 32;    // lane within that group
    const int halfIdx = is / 4;     // which 128-element half
    const int grp = is % 4;         // group within the half

    const int qsIdx = halfIdx * 32 + iq;
    const int scaleIdx = halfIdx * 8 + grp * 2 + iq / 16;
    const int shift = grp * 2;
    // hmask is indexed by lane only; the bit position encodes half and group.
    const unsigned char m = 1 << (halfIdx * 4 + grp);

    const signed char scale = unpack_q3_scale(b->scales, scaleIdx);
    int qv = (b->qs[qsIdx] >> shift) & 0x3;
    if ((b->hmask[iq] & m) == 0) {
      qv -= 4;
    }

    // 64-bit offset: blockIdx_q * 256 in int overflows past 2^31 output elements.
    const size_t outIdx = (size_t)blockIdx_q * 256 + (size_t)elemIdx;
    out[outIdx] = d * (float)scale * (float)qv;
  }
  // Host launcher for dequant_q3k_kernel.
  // Launches one CUDA block per Q3_K super-block, one thread per element, on
  // the default stream (no stream argument) without synchronizing. Launch
  // errors are checked through CHECK_CUDA(cudaGetLastError()), whose failure
  // behavior is defined in cuda_common.cuh. Returns 0 on the normal path.
  int cuda_dequant_q3k(const void* blocks, float* out, int numBlocks) {
    const int threadsPerBlock = 256;  // one thread per element of a super-block
    if (numBlocks == 0) {
      return 0;  // nothing to do; a zero-sized grid is not a valid launch
    }
    dequant_q3k_kernel<<<numBlocks, threadsPerBlock>>>(
        reinterpret_cast<const BlockQ3_K*>(blocks), out, numBlocks);
    CHECK_CUDA(cudaGetLastError());
    return 0;
  }
  73. // ============================================================
  74. // Q2_K Dequantization Kernel
  75. // ============================================================
  // Dequantizes Q2_K super-blocks (256 elements each) to float32.
  //
  // Launch layout (as used by cuda_dequant_q2k): one CUDA block per super-block
  // (gridDim.x >= numBlocks, extra blocks exit early) and exactly 256 threads
  // per block. Thread t writes element t, so `out` must hold at least
  // numBlocks * 256 floats.
  //
  // Element e is decoded with h = e / 128, g = (e % 128) / 32, l = e % 32:
  //   quant : qs[h*32 + l], bit pair (2g, 2g+1), value 0..3
  //   scale : byte scales[h*8 + 2g + l/16]; low nibble * d is the step,
  //           high nibble * dmin is the subtracted minimum
  //   value : (d * lo) * q - (dmin * hi)
  __global__ void dequant_q2k_kernel(const BlockQ2_K* blocks, float* out, int numBlocks) {
    const int blockIdx_q = blockIdx.x;
    if (blockIdx_q >= numBlocks) return;
    const int elemIdx = threadIdx.x;  // 0..255 (launched with 256 threads)

    const BlockQ2_K* b = &blocks[blockIdx_q];
    const float d = fp16_to_fp32(b->d);
    const float dmin = fp16_to_fp32(b->dmin);

    const int is = elemIdx / 32;    // 32-element group within the super-block, 0..7
    const int iq = elemIdx % 32;    // lane within that group
    const int halfIdx = is / 4;     // which 128-element half
    const int grp = is % 4;         // group within the half

    const int scIdx = halfIdx * 8 + grp * 2 + iq / 16;
    const int qsIdx = halfIdx * 32 + iq;
    const int shift = grp * 2;

    const unsigned char sc = b->scales[scIdx];
    const float dl = d * (float)(sc & 0xF);
    const float ml = dmin * (float)(sc >> 4);
    const int val = (b->qs[qsIdx] >> shift) & 3;

    // 64-bit offset: blockIdx_q * 256 in int overflows past 2^31 output elements.
    const size_t outIdx = (size_t)blockIdx_q * 256 + (size_t)elemIdx;
    out[outIdx] = dl * (float)val - ml;
  }
  // Host launcher for dequant_q2k_kernel.
  // Grid: one CUDA block per Q2_K super-block; block: 256 threads (one per
  // element). Runs on the default stream and does not synchronize; only the
  // launch itself is checked, via CHECK_CUDA(cudaGetLastError()) from
  // cuda_common.cuh. Returns 0 on the normal path, including numBlocks == 0.
  int cuda_dequant_q2k(const void* blocks, float* out, int numBlocks) {
    if (!numBlocks) return 0;

    const dim3 grid(numBlocks);
    const dim3 block(256);
    const BlockQ2_K* typedBlocks = (const BlockQ2_K*)blocks;

    dequant_q2k_kernel<<<grid, block>>>(typedBlocks, out, numBlocks);
    CHECK_CUDA(cudaGetLastError());
    return 0;
  }