@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
 
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
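+// 2*MMQ_ITER_K: FP4 packs two values per byte, so a 512-value K iteration reads the same number of bytes per row as a 256-value q8_1 iteration.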
+#define MMQ_ITER_K_MXFP4_FP4 512
 #define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -44,8 +45,15 @@ struct block_q8_1_mmq {
     };
     int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
 };
+
+struct block_fp4_mmq {
+    uint32_t d4[4];     // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+    int8_t   qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
+};
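+// 16 bytes of scales + 128 bytes of quants = 144 bytes, the same size as
+// block_q8_1_mmq (4*QK8_1 bytes of quants + 4*sizeof(half2) of scales), as asserted below.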
+
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected block_fp4_mmq size");
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
@@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
         ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
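+// K values consumed per outer MMQ iteration; widened for MXFP4 when native FP4 MMA is available.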
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+    return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
 static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIP)
 #if defined(RDNA1)
@@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
 }
 
 #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
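+// 2*MMQ_TILE_NE_K ints of nibble-packed FP4 data (512 values) + 8 uint32 of packed E8M0 scales (16 blocks) + 4 ints of padding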
+#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4)
 #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4)
 #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4)
 #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4)
@@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
@@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
+        // Tile sizes are the same for Q8_1 and FP4 on Blackwell.
         case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
@@ -228,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 }
 
 // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints) + 4 32-bit scales
-#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
+#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
@@ -761,6 +782,50 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     }
 }
 
+template <int mmq_y, bool need_check>
+static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
+                                                            int * __restrict__ x_tile,
+                                                            const int kbx0,
+                                                            const int i_max,
+                                                            const int stride) {
+    constexpr int nwarps    = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    int *      x_qs = (int *) x_tile;
+    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+
+    const int txi = threadIdx.x;
+
+    constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
+
+    constexpr int threads_per_row = iter_k / QK_MXFP4; // each thread processes 1 block
+    constexpr int rows_per_warp   = warp_size / threads_per_row;
+    const int kbx         = txi % threads_per_row;
+    const int row_in_warp = txi / threads_per_row;
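+    // With iter_k == 512: threads_per_row == 16, so each 32-thread warp loads 2 rows per pass.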
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
+
+        if constexpr (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
+
+        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
+        const int k0 = kbx * 4;
+        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
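+        // 16 bytes == QK_MXFP4/2 nibble-packed bytes: one block's quants, 4 ints starting at k0.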
+
+        // Load E8M0 scales: pack 2 consecutive scales into one uint32
+        if (kbx % 2 == 0) {
+            uint32_t e = bxi->e;
+            e |= ((bxi + 1)->e << 8);
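+            // low byte: scale of block kbx; high byte: scale of block kbx + 1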
+            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
+        }
+    }
+}
+
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -931,6 +996,78 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
+                                                               const int * __restrict__ y,
+                                                               float * __restrict__ sum,
+                                                               const int k00) {
+    typedef tile<16, 8, int>   tile_A;
+    typedef tile<8, 8, int>    tile_B;
+    typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA.
+
+    constexpr int granularity   = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx           = rows_per_warp / tile_C::I; // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
+
+    // Match the layout written by load_tiles_mxfp4_fp4.
+    const int *      x_qs = (const int *) x;
+    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+    const int *      y_qs = (const int *) y + 4;
+    const uint32_t * y_sc = (const uint32_t *) y;
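+    // block_fp4_mmq stores its 4 uint32 of E8M0 scales first, then the packed quants,
+    // hence the +4 int offset for y_qs.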
+
+    // Each tile_A row spans 64 logical FP4 values vs. the 32 values in one block_mxfp4.
+    tile_A   A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+
+    // Block scales: each thread has to point at a 4-byte scale value, see
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
|
|
|
+#pragma unroll
|
|
|
+ for (int n = 0; n < ntx; ++n) {
|
|
|
+#pragma unroll
|
|
|
+ for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
|
|
|
+ const int k0 = k00 + k01;
|
|
|
+
|
|
|
+ load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
|
|
|
+ MMQ_MMA_TILE_X_K_FP4);
|
|
|
+
|
|
|
+ // based on block-scaling document, 2 threads in each quad need to supply to the scale value
|
|
|
+ const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
+            scaleA[n][k01 / (2 * QI_MXFP4)] =
+                *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+            tile_B   B;
+            uint32_t scaleB; // 2xN scales
+
+            load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
+
+            scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+
+                mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
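+                // Block-scaled MMA: C = A*B with each 32-value block scaled by its E8M0 factor (see spec linked above).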
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
+                }
+            }
+        }
+    }
+}
+
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -3109,8 +3246,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
     static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
+#ifdef BLACKWELL_MMA_AVAILABLE
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4_fp4<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
+#else
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4<mmq_y, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+#endif // BLACKWELL_MMA_AVAILABLE
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };
 
@@ -3243,17 +3385,26 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    // An FP4 tile stores 8 MXFP4 blocks (256 values) vs. 4 q8_1 blocks (128 values).
+    constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
+#else
+    constexpr int ne_block = 4 * QK8_1;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+
+    constexpr int ITER_K = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
+    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
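+    // sz: ints per y block; block_fp4_mmq has the same size by the static_assert above.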
+
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
-
         {
-            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
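+            // kb0*qk values consumed so far; / ne_block gives the y block index, * sz converts blocks to ints.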
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
@@ -3267,9 +3418,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         __syncthreads();
 
         {
-            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
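+            // Second slice: same base as the first load, advanced by one block (sz ints) per column.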
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
@@ -3401,8 +3552,10 @@ static __global__ void mul_mat_q(
     }
 #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
+    constexpr int ITER_K = get_iter_k(type);
+
     const int64_t blocks_per_ne00 = ncols_x / qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     // kbc == k block continuous, current index in continuous ijk space.
     int64_t kbc = (int64_t) blockIdx.x*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -3463,7 +3616,7 @@ static __global__ void mul_mat_q(
         __syncthreads();
     }
 
-    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+    offset_y   += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
     offset_dst += it*mmq_y;
 
     const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3530,7 +3683,7 @@ static __global__ void mul_mat_q(
         __syncthreads();
     }
 
-    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+    offset_y   += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
     offset_dst += it*mmq_y;
 
     const int tile_x_max_i = nrows_x - it*mmq_y - 1;
@@ -3553,7 +3706,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int ncols_max) {
     constexpr int mmq_y = get_mmq_y_device();
     constexpr int qk    = ggml_cuda_type_traits<type>::qk;
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int ITER_K = get_iter_k(type);
+
+    constexpr int blocks_per_iter = ITER_K / qk;
     const int64_t blocks_per_ne00 = ncols_x / qk;
 
     constexpr int nwarps = mmq_get_nwarps_device();
@@ -3711,7 +3866,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
     const size_t nbs_ids = mmq_x*sizeof(int);
     const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
-    const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
+    const size_t nbs_y = mmq_x * sizeof(block_q8_1_mmq);
     return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
 }