|
|
@@ -47,10 +47,11 @@ struct ggml_metal_context {
|
|
|
GGML_METAL_DECL_KERNEL(relu);
|
|
|
GGML_METAL_DECL_KERNEL(soft_max);
|
|
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
|
|
+ GGML_METAL_DECL_KERNEL(get_rows_f16);
|
|
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
|
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
|
|
- GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
|
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
|
|
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
|
|
GGML_METAL_DECL_KERNEL(rope);
|
|
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
|
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
|
|
@@ -130,10 +131,11 @@ struct ggml_metal_context * ggml_metal_init(void) {
|
|
|
GGML_METAL_ADD_KERNEL(relu);
|
|
|
GGML_METAL_ADD_KERNEL(soft_max);
|
|
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
|
|
+ GGML_METAL_ADD_KERNEL(get_rows_f16);
|
|
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
|
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
|
|
- GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
|
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
|
|
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
|
|
GGML_METAL_ADD_KERNEL(rope);
|
|
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
|
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
|
|
@@ -498,6 +500,14 @@ void ggml_metal_graph_compute(
|
|
|
|
|
|
// use custom matrix x vector kernel
|
|
|
switch (src0t) {
|
|
|
+ case GGML_TYPE_F16:
|
|
|
+ {
|
|
|
+ GGML_ASSERT(ne02 == ne12);
|
|
|
+
|
|
|
+ nth0 = 64;
|
|
|
+ nth1 = 1;
|
|
|
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
|
|
+ } break;
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
{
|
|
|
GGML_ASSERT(ne02 == 1);
|
|
|
@@ -507,14 +517,6 @@ void ggml_metal_graph_compute(
|
|
|
nth1 = 4;
|
|
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
|
|
|
} break;
|
|
|
- case GGML_TYPE_F16:
|
|
|
- {
|
|
|
- GGML_ASSERT(ne02 == ne12);
|
|
|
-
|
|
|
- nth0 = 32;
|
|
|
- nth1 = 1;
|
|
|
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
|
|
- } break;
|
|
|
default: GGML_ASSERT(false && "not implemented");
|
|
|
};
|
|
|
|
|
|
@@ -551,6 +553,7 @@ void ggml_metal_graph_compute(
|
|
|
}
|
|
|
|
|
|
switch (src0->type) {
|
|
|
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
|
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
|
|
default: GGML_ASSERT(false && "not implemented");
|
|
|
}
|