|
@@ -65,6 +65,7 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
|
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
|
|
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
|
|
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
|
|
|
|
|
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
|
|
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
|
|
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
|
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
|
@@ -91,6 +92,7 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
|
|
|
|
|
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
|
|
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
|
//GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
|
|
@@ -113,6 +115,7 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
|
|
|
|
|
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
|
|
@@ -132,6 +135,7 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
|
|
|
|
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
|
|
@@ -151,6 +155,7 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
|
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
|
|
|
|
|
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
|
GGML_METAL_KERNEL_TYPE_ROPE_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
|
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
@@ -466,6 +471,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
|
@@ -492,6 +498,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, mul_mv_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
|
|
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
|
|
@@ -514,6 +521,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, mul_mv_id_iq4_xs_f32, ctx->support_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
|
@@ -533,6 +541,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
|
@@ -552,6 +561,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32, mul_mm_id_iq4_xs_f32, ctx->support_simdgroup_mm);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
@@ -1371,6 +1381,7 @@ static bool ggml_metal_graph_compute(
|
|
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break;
|
|
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break;
|
|
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break;
|
|
|
|
|
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break;
|
|
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
|
default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -1529,6 +1540,12 @@ static bool ggml_metal_graph_compute(
|
|
|
nth1 = 16;
|
|
nth1 = 16;
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case GGML_TYPE_IQ4_XS:
|
|
|
|
|
+ {
|
|
|
|
|
+ nth0 = 4;
|
|
|
|
|
+ nth1 = 16;
|
|
|
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
|
|
|
|
|
+ } break;
|
|
|
default:
|
|
default:
|
|
|
{
|
|
{
|
|
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
|
|
@@ -1576,7 +1593,7 @@ static bool ggml_metal_graph_compute(
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
}
|
|
}
|
|
|
- else if (src0t == GGML_TYPE_IQ4_NL) {
|
|
|
|
|
|
|
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
|
|
const int mem_size = 32*sizeof(float);
|
|
const int mem_size = 32*sizeof(float);
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
@@ -1678,6 +1695,7 @@ static bool ggml_metal_graph_compute(
|
|
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline; break;
|
|
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline; break;
|
|
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
|
|
|
|
|
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
|
|
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
|
default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -1839,6 +1857,12 @@ static bool ggml_metal_graph_compute(
|
|
|
nth1 = 16;
|
|
nth1 = 16;
|
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case GGML_TYPE_IQ4_XS:
|
|
|
|
|
+ {
|
|
|
|
|
+ nth0 = 4;
|
|
|
|
|
+ nth1 = 16;
|
|
|
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
|
|
|
|
|
+ } break;
|
|
|
default:
|
|
default:
|
|
|
{
|
|
{
|
|
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
|
|
@@ -1902,7 +1926,7 @@ static bool ggml_metal_graph_compute(
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
|
}
|
|
}
|
|
|
- else if (src2t == GGML_TYPE_IQ4_NL) {
|
|
|
|
|
|
|
+ else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
|
|
|
const int mem_size = 32*sizeof(float);
|
|
const int mem_size = 32*sizeof(float);
|
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
|
@@ -1952,6 +1976,7 @@ static bool ggml_metal_graph_compute(
|
|
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
|
|
case GGML_TYPE_IQ2_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline; break;
|
|
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
|
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline; break;
|
|
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
|
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline; break;
|
|
|
|
|
+ case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS ].pipeline; break;
|
|
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
|
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline; break;
|
|
|
default: GGML_ASSERT(false && "not implemented");
|
|
default: GGML_ASSERT(false && "not implemented");
|
|
|
}
|
|
}
|