cuda_dequant_q8k.cu 828 B

123456789101112131415161718192021222324
  1. #include "cuda_common.cuh"
  2. // ============================================================
  3. // Q8_K Dequantization Kernel
  4. // Each thread handles 1 element within a block
  5. // ============================================================
  // Dequantizes Q8_K blocks into floats:
  //   out[b * 256 + i] = (float)blocks[b].d * (float)blocks[b].qs[i]
  //
  // Launch layout: one CUDA block per quantized block (blockIdx.x selects the
  // quantized block). With 256 threads per block each thread writes exactly one
  // element; any other blockDim.x is also handled, because the threads of a
  // block stride over the 256 elements.
  //
  // blocks    - device pointer to numBlocks BlockQ8_K entries
  //             (assumes qs holds at least 256 entries per block - confirm
  //             against the BlockQ8_K definition).
  // out       - device pointer receiving numBlocks * 256 floats.
  // numBlocks - number of quantized blocks; CUDA blocks with an index at or
  //             beyond this value exit without touching memory.
  __global__ void dequant_q8k_kernel(const BlockQ8_K* blocks, float* out, int numBlocks) {
    const int kElemsPerBlock = 256;  // elements written per quantized block

    const int blockIdx_q = blockIdx.x;
    if (blockIdx_q >= numBlocks) return;

    const BlockQ8_K* b = &blocks[blockIdx_q];
    const float d = b->d;  // per-block scale, converted to float once

    // Compute the base offset in 64 bits: blockIdx_q * 256 does not fit in a
    // 32-bit int once the block index reaches 2^23.
    float* dst = out + (size_t)blockIdx_q * kElemsPerBlock;

    // Block-stride loop: with blockDim.x == 256 this runs once per thread.
    // The bound keeps oversized launches from indexing past the 256 elements,
    // and the stride keeps undersized launches from leaving elements unwritten.
    for (int i = threadIdx.x; i < kElemsPerBlock; i += blockDim.x) {
      dst[i] = d * (float)b->qs[i];
    }
  }
  // Host entry point: launches dequant_q8k_kernel over numBlocks Q8_K blocks.
  //
  // blocks    - device pointer to the quantized data, reinterpreted as
  //             BlockQ8_K entries.
  // out       - device pointer for the float output (the kernel writes 256
  //             floats per quantized block).
  // numBlocks - number of quantized blocks; 0 skips the launch.
  //
  // Returns 0. The launch uses the default stream and is not synchronized
  // here, so only launch errors reported by cudaGetLastError() reach
  // CHECK_CUDA; what happens on failure is defined by that macro.
  int cuda_dequant_q8k(const void* blocks, float* out, int numBlocks) {
    if (numBlocks != 0) {
      const BlockQ8_K* q8kBlocks = static_cast<const BlockQ8_K*>(blocks);

      // One CUDA block per quantized block, one thread per element.
      const dim3 grid(numBlocks);
      const dim3 threadsPerBlock(256);

      dequant_q8k_kernel<<<grid, threadsPerBlock>>>(q8kBlocks, out, numBlocks);
      CHECK_CUDA(cudaGetLastError());
    }
    return 0;
  }