|
|
@@ -155,25 +155,27 @@ static constexpr __device__ int get_mmq_y_device() {
|
|
|
#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8}
|
|
|
|
|
|
static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
|
|
|
- return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
|
|
|
- type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
|
|
|
- type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
|
|
|
- type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
|
|
|
- type == GGML_TYPE_Q8_0 ? MMQ_DP4A_TXS_Q8_0 :
|
|
|
- type == GGML_TYPE_Q2_K ? MMQ_DP4A_TXS_Q2_K :
|
|
|
- type == GGML_TYPE_Q3_K ? MMQ_DP4A_TXS_Q3_K :
|
|
|
- type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K :
|
|
|
- type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K :
|
|
|
- type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K :
|
|
|
- type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 :
|
|
|
- type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
|
|
|
- type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
|
|
|
- tile_x_sizes{0, 0, 0};
|
|
|
+ switch (type) {
|
|
|
+ case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0;
|
|
|
+ case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1;
|
|
|
+ case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
|
|
|
+ case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
|
|
|
+ case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
|
|
|
+ case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
|
|
|
+ case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
|
|
|
+ case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
|
|
|
+ case GGML_TYPE_Q5_K: return MMQ_DP4A_TXS_Q5_K;
|
|
|
+ case GGML_TYPE_Q6_K: return MMQ_DP4A_TXS_Q6_K;
|
|
|
+ case GGML_TYPE_IQ2_XXS: return MMQ_DP4A_TXS_Q8_0;
|
|
|
+ case GGML_TYPE_IQ2_XS: return MMQ_DP4A_TXS_Q8_0_16;
|
|
|
+ case GGML_TYPE_IQ2_S: return MMQ_DP4A_TXS_Q8_0_16;
|
|
|
+ case GGML_TYPE_IQ3_XXS: return MMQ_DP4A_TXS_Q8_0;
|
|
|
+ case GGML_TYPE_IQ3_S: return MMQ_DP4A_TXS_Q8_0;
|
|
|
+ case GGML_TYPE_IQ1_S: return MMQ_DP4A_TXS_Q8_0;
|
|
|
+ case GGML_TYPE_IQ4_XS: return MMQ_DP4A_TXS_Q8_0;
|
|
|
+ case GGML_TYPE_IQ4_NL: return MMQ_DP4A_TXS_Q8_0;
|
|
|
+ default: return tile_x_sizes{0, 0, 0};
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4)
|
|
|
@@ -189,25 +191,27 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
|
|
|
static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
|
|
|
|
|
|
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
|
|
|
- return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
|
|
|
- type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
|
|
|
- type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
|
|
|
- type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
|
|
|
- type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
|
|
|
- type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K :
|
|
|
- type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K :
|
|
|
- type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 :
|
|
|
- type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 :
|
|
|
- type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K :
|
|
|
- type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K :
|
|
|
- type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
|
|
|
- type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
|
|
|
- type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
|
|
|
- 0;
|
|
|
+ switch (type) {
|
|
|
+ case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
+ case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
|
+ case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
+ case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
|
+ case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
+ case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
|
|
|
+ case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
|
|
|
+ case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
|
+ case GGML_TYPE_Q5_K: return MMQ_MMA_TILE_X_K_Q8_1;
|
|
|
+ case GGML_TYPE_Q6_K: return MMQ_MMA_TILE_X_K_Q6_K;
|
|
|
+ case GGML_TYPE_IQ2_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
+ case GGML_TYPE_IQ2_XS: return MMQ_MMA_TILE_X_K_Q3_K;
|
|
|
+ case GGML_TYPE_IQ2_S: return MMQ_MMA_TILE_X_K_Q3_K;
|
|
|
+ case GGML_TYPE_IQ3_XXS: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
+ case GGML_TYPE_IQ3_S: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
+ case GGML_TYPE_IQ1_S: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
+ case GGML_TYPE_IQ4_XS: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
+ case GGML_TYPE_IQ4_NL: return MMQ_MMA_TILE_X_K_Q8_0;
|
|
|
+ default: return 0;
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
|