| 123456789101112131415161718192021222324 |
- #include "cuda_common.cuh"
// ============================================================
// Q8_K Dequantization Kernel
//
// Launch layout: one CUDA block per quantized block (blockIdx.x selects the
// quantized block). The host wrapper in this file launches 256 threads per
// block, in which case each thread handles exactly 1 element. Other values
// of blockDim.x are also handled: each thread steps through the 256
// elements in increments of blockDim.x.
//
// blocks    : array of numBlocks quantized blocks; each supplies a scale `d`
//             and quantized values `qs` (assumed to hold at least 256
//             entries, per the 0-255 element range -- TODO confirm against
//             the BlockQ8_K definition).
// out       : receives numBlocks * 256 floats,
//             out[b * 256 + i] = blocks[b].d * blocks[b].qs[i].
// numBlocks : number of quantized blocks to process.
// ============================================================
__global__ void dequant_q8k_kernel(const BlockQ8_K* blocks, float* out, int numBlocks) {
    const int kElemsPerBlock = 256;

    const int blockIdx_q = blockIdx.x;
    if (blockIdx_q >= numBlocks) return;

    const BlockQ8_K* b = &blocks[blockIdx_q];
    const float d = b->d;

    // Widen before multiplying: blockIdx_q * 256 no longer fits in a 32-bit
    // int once numBlocks reaches 2^23, which would corrupt the output offset.
    float* dst = out + (size_t)blockIdx_q * kElemsPerBlock;

    // Bounded by kElemsPerBlock so qs is never indexed past element 255,
    // whatever block size the kernel is launched with.
    for (int i = threadIdx.x; i < kElemsPerBlock; i += blockDim.x) {
        dst[i] = d * (float)b->qs[i];
    }
}
// Host-side launcher for dequant_q8k_kernel.
//
// Reinterprets `blocks` as an array of BlockQ8_K and launches one CUDA block
// of 256 threads per quantized block on the default stream. No explicit
// synchronization call is made here; only the launch status reported by
// cudaGetLastError() is passed to CHECK_CUDA.
//
// Returns 0 when numBlocks is 0 (nothing is launched) and 0 after the launch
// status check; what happens on a failed check is determined by CHECK_CUDA.
//
// NOTE(review): `blocks` and `out` are presumably device pointers, with `out`
// sized for numBlocks * 256 floats -- confirm against callers.
int cuda_dequant_q8k(const void* blocks, float* out, int numBlocks) {
    if (numBlocks != 0) {
        const BlockQ8_K* q8kBlocks = (const BlockQ8_K*)blocks;
        const dim3 grid(numBlocks);   // one CUDA block per quantized block
        const dim3 threads(256);      // one thread per element

        dequant_q8k_kernel<<<grid, threads>>>(q8kBlocks, out, numBlocks);
        CHECK_CUDA(cudaGetLastError());
    }
    return 0;
}
|