|
@@ -486,6 +486,7 @@ struct vk_device_struct {
|
|
|
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
|
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
|
|
|
|
|
|
|
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
|
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
|
|
|
|
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
|
|
|
|
|
|
|
|
vk_pipeline pipeline_matmul_split_k_reduce;
|
|
vk_pipeline pipeline_matmul_split_k_reduce;
|
|
|
vk_pipeline pipeline_quantize_q8_1;
|
|
vk_pipeline pipeline_quantize_q8_1;
|
|
@@ -2448,8 +2449,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
l_warptile_id, m_warptile_id, s_warptile_id,
|
|
l_warptile_id, m_warptile_id, s_warptile_id,
|
|
|
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
|
|
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
|
|
|
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
|
|
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
|
|
|
|
|
+ l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
|
|
|
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
|
|
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
|
|
|
- l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
|
|
|
|
|
|
|
+ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
|
|
|
|
|
+ l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
|
|
|
|
|
+ l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
|
|
|
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
|
|
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
|
|
|
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
|
|
l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
|
|
|
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
|
|
l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
|
|
@@ -2512,10 +2516,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
|
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
|
|
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
|
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
|
|
|
|
|
|
|
|
|
+ // Integer MMQ has a smaller shared memory profile, but heavier register use
|
|
|
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
|
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
|
|
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
|
|
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
|
|
|
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
|
|
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
|
|
|
|
|
|
|
|
|
|
+ // K-quants use even more registers, mitigate by setting WMITER to 1
|
|
|
|
|
+ l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
|
|
|
|
|
+ m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
|
|
|
|
|
+ s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };
|
|
|
|
|
+
|
|
|
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
|
|
l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
|
|
|
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
|
|
m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
|
|
|
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
|
|
s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
|
|
@@ -2524,10 +2534,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
|
|
m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
|
|
|
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
|
|
s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
|
|
|
|
|
|
|
|
|
|
+ l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
|
|
|
|
|
+ m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
|
|
|
|
|
+ s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
|
|
|
|
|
+
|
|
|
|
|
+ l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
|
|
|
|
|
+ m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
|
|
|
|
|
+ s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
|
|
|
|
|
+
|
|
|
// chip specific tuning
|
|
// chip specific tuning
|
|
|
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
|
|
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
|
|
|
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
|
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
|
|
- m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
|
|
|
|
|
|
+ m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
|
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
|
@@ -2912,18 +2930,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
|
|
|
|
|
|
-#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
|
|
|
|
|
+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
|
|
|
if (device->mul_mat ## ID ## _l[TYPE]) { \
|
|
if (device->mul_mat ## ID ## _l[TYPE]) { \
|
|
|
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
|
|
|
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
|
} \
|
|
} \
|
|
|
if (device->mul_mat ## ID ## _m[TYPE]) { \
|
|
if (device->mul_mat ## ID ## _m[TYPE]) { \
|
|
|
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
|
|
|
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
|
} \
|
|
} \
|
|
|
if (device->mul_mat ## ID ## _s[TYPE]) { \
|
|
if (device->mul_mat ## ID ## _s[TYPE]) { \
|
|
|
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
|
|
|
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
|
|
|
|
|
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
|
|
|
} \
|
|
} \
|
|
|
|
|
|
|
|
// Create 2 variants, {f16,f32} accumulator
|
|
// Create 2 variants, {f16,f32} accumulator
|
|
@@ -2962,11 +2977,19 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
|
|
|
|
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
|
if (device->integer_dot_product) {
|
|
if (device->integer_dot_product) {
|
|
|
- CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
- CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
- CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
- CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
- CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
|
|
|
}
|
|
}
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
|
@@ -2996,6 +3019,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
|
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
|
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
|
|
|
+
|
|
|
|
|
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
|
|
|
+ if (device->integer_dot_product) {
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
|
|
|
+
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size);
|
|
|
|
|
+
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16);
|
|
|
|
|
+ }
|
|
|
|
|
+#endif
|
|
|
} else {
|
|
} else {
|
|
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
|
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
|
|
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
|
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0);
|
|
@@ -3022,6 +3063,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+
|
|
|
|
|
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
|
|
|
+ if (device->integer_dot_product) {
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, 4, _id, 0);
|
|
|
|
|
+ }
|
|
|
|
|
+#endif
|
|
|
}
|
|
}
|
|
|
#undef CREATE_MM2
|
|
#undef CREATE_MM2
|
|
|
#undef CREATE_MMQ
|
|
#undef CREATE_MMQ
|
|
@@ -3086,6 +3145,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
|
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
|
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
+
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
|
|
|
|
+ CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
|
|
|
}
|
|
}
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
|
@@ -3145,7 +3210,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
}
|
|
}
|
|
|
// reusing CREATE_MM from the fp32 path
|
|
// reusing CREATE_MM from the fp32 path
|
|
|
if ((device->coopmat2 || device->coopmat_support)
|
|
if ((device->coopmat2 || device->coopmat_support)
|
|
|
-#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
|
|
|
|
|
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
|
&& !device->coopmat_bf16_support
|
|
&& !device->coopmat_bf16_support
|
|
|
#endif
|
|
#endif
|
|
|
) {
|
|
) {
|
|
@@ -4928,7 +4993,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
|
|
|
|
|
|
// MMQ
|
|
// MMQ
|
|
|
if (src1_type == GGML_TYPE_Q8_1) {
|
|
if (src1_type == GGML_TYPE_Q8_1) {
|
|
|
- vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
|
|
|
|
|
|
|
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
|
|
|
|
|
|
|
|
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
|
|
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
|
|
|
return nullptr;
|
|
return nullptr;
|
|
@@ -5075,6 +5140,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // MMQ
|
|
|
|
|
+ if (src1_type == GGML_TYPE_Q8_1) {
|
|
|
|
|
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
|
|
|
|
|
+
|
|
|
|
|
+ if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
|
|
|
|
|
+ return nullptr;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return pipelines;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
|
|
GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
|
|
|
|
|
|
|
|
switch (src0_type) {
|
|
switch (src0_type) {
|
|
@@ -6877,10 +6953,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
|
|
|
|
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
|
|
|
|
|
|
|
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
|
|
|
|
|
|
+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
|
|
|
|
|
+
|
|
|
|
|
+ // Check for mmq first
|
|
|
|
|
+ vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
|
|
|
|
|
+
|
|
|
|
|
+ if (mmp == nullptr) {
|
|
|
|
|
+ // Fall back to f16 dequant mul mat
|
|
|
|
|
+ mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
|
|
|
|
+ quantize_y = false;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
|
|
- const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
|
|
|
|
|
|
|
+ const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
|
|
|
|
|
|
|
|
if (qx_needs_dequant) {
|
|
if (qx_needs_dequant) {
|
|
|
// Fall back to dequant + f16 mulmat
|
|
// Fall back to dequant + f16 mulmat
|
|
@@ -6890,8 +6975,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
// Not implemented
|
|
// Not implemented
|
|
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
|
|
|
|
|
|
|
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
|
|
|
|
|
- const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
|
|
|
|
|
|
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
|
|
|
|
|
+ const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
|
|
|
|
|
|
|
|
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
|
|
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
|
|
|
|
|
|
|
@@ -6904,12 +6989,13 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
|
|
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
|
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
|
|
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
|
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
|
|
- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
|
|
|
|
|
|
|
+ const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
|
|
|
const uint64_t ids_sz = nbi2;
|
|
const uint64_t ids_sz = nbi2;
|
|
|
const uint64_t d_sz = sizeof(float) * d_ne;
|
|
const uint64_t d_sz = sizeof(float) * d_ne;
|
|
|
|
|
|
|
|
vk_pipeline to_fp16_vk_0 = nullptr;
|
|
vk_pipeline to_fp16_vk_0 = nullptr;
|
|
|
vk_pipeline to_fp16_vk_1 = nullptr;
|
|
vk_pipeline to_fp16_vk_1 = nullptr;
|
|
|
|
|
+ vk_pipeline to_q8_1 = nullptr;
|
|
|
|
|
|
|
|
if (x_non_contig) {
|
|
if (x_non_contig) {
|
|
|
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
|
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
|
@@ -6924,9 +7010,16 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
|
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
|
|
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
|
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
|
|
|
|
|
|
|
|
|
+ if (quantize_y) {
|
|
|
|
|
+ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if (dryrun) {
|
|
if (dryrun) {
|
|
|
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
|
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
|
|
- const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
|
|
|
|
|
|
+ uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
|
|
|
|
+ if (quantize_y) {
|
|
|
|
|
+ y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144;
|
|
|
|
|
+ }
|
|
|
if (
|
|
if (
|
|
|
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
|
|
(qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) ||
|
|
|
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
|
|
(qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) {
|
|
@@ -6935,7 +7028,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
|
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
|
|
ctx->prealloc_size_x = x_sz_upd;
|
|
ctx->prealloc_size_x = x_sz_upd;
|
|
|
}
|
|
}
|
|
|
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
|
|
|
|
|
|
|
+ if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
|
|
|
ctx->prealloc_size_y = y_sz_upd;
|
|
ctx->prealloc_size_y = y_sz_upd;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -6947,6 +7040,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
if (qy_needs_dequant) {
|
|
if (qy_needs_dequant) {
|
|
|
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
|
|
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
|
|
|
}
|
|
}
|
|
|
|
|
+ if (quantize_y) {
|
|
|
|
|
+ ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
|
|
|
|
|
+ }
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -6983,6 +7079,9 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
if (qy_needs_dequant) {
|
|
if (qy_needs_dequant) {
|
|
|
d_Y = ctx->prealloc_y;
|
|
d_Y = ctx->prealloc_y;
|
|
|
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
|
|
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
|
|
|
|
|
+ } else if (quantize_y) {
|
|
|
|
|
+ d_Y = ctx->prealloc_y;
|
|
|
|
|
+ GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144);
|
|
|
} else {
|
|
} else {
|
|
|
d_Y = d_Qy;
|
|
d_Y = d_Qy;
|
|
|
y_buf_offset = qy_buf_offset;
|
|
y_buf_offset = qy_buf_offset;
|
|
@@ -7014,6 +7113,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
ctx->prealloc_y_last_tensor_used = src1;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+ if (quantize_y) {
|
|
|
|
|
+ if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
|
|
|
|
|
+ ctx->prealloc_y_last_tensor_used != src1) {
|
|
|
|
|
+ if (ctx->prealloc_y_need_sync) {
|
|
|
|
|
+ ggml_vk_sync_buffers(ctx, subctx);
|
|
|
|
|
+ }
|
|
|
|
|
+ ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true);
|
|
|
|
|
+ ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
|
|
|
|
|
+ ctx->prealloc_y_last_tensor_used = src1;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
uint32_t stride_batch_x = ne00*ne01;
|
|
uint32_t stride_batch_x = ne00*ne01;
|
|
|
uint32_t stride_batch_y = ne10*ne11;
|
|
uint32_t stride_batch_y = ne10*ne11;
|
|
@@ -7022,14 +7132,19 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
|
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
|
|
|
|
|
|
|
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
|
|
|
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
|
|
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ // Keep the Y byte size 64-bit, matching y_sz (uint64_t) and y_sz_upd in the dryrun path; a 32-bit local would truncate sizes >= 4 GiB before they reach the subbuffer range.
+ uint64_t y_sz_total = y_sz * ne12 * ne13;
|
|
|
|
|
+ if (quantize_y) {
|
|
|
|
|
+ y_sz_total = CEIL_DIV(y_sz_total, 144) * 144;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
// compute
|
|
// compute
|
|
|
ggml_vk_matmul_id(
|
|
ggml_vk_matmul_id(
|
|
|
ctx, subctx, pipeline,
|
|
ctx, subctx, pipeline,
|
|
|
- { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 },
|
|
|
|
|
|
|
+ { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total },
|
|
|
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
|
|
{ d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
|
|
|
ne01, ne21, ne10, ne10, ne10, ne01,
|
|
ne01, ne21, ne10, ne10, ne10, ne01,
|
|
|
stride_batch_x, stride_batch_y, ne20*ne21,
|
|
stride_batch_x, stride_batch_y, ne20*ne21,
|