|
@@ -135,6 +135,7 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
|
GGML_METAL_KERNEL_TYPE_ROPE_F16,
|
|
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
|
|
|
|
+ GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
@@ -506,6 +507,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
@@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|
|
case GGML_OP_ALIBI:
|
|
case GGML_OP_ALIBI:
|
|
|
case GGML_OP_ROPE:
|
|
case GGML_OP_ROPE:
|
|
|
case GGML_OP_IM2COL:
|
|
case GGML_OP_IM2COL:
|
|
|
|
|
+ return true;
|
|
|
|
|
+ case GGML_OP_POOL_1D:
|
|
|
|
|
+ case GGML_OP_POOL_2D:
|
|
|
|
|
+ return false;
|
|
|
case GGML_OP_UPSCALE:
|
|
case GGML_OP_UPSCALE:
|
|
|
case GGML_OP_PAD:
|
|
case GGML_OP_PAD:
|
|
|
case GGML_OP_ARGSORT:
|
|
case GGML_OP_ARGSORT:
|
|
@@ -2015,7 +2021,7 @@ static bool ggml_metal_graph_compute(
|
|
|
{
|
|
{
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
|
- GGML_ASSERT( dst->type == GGML_TYPE_F16);
|
|
|
|
|
|
|
+ GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
|
|
|
|
|
|
|
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
|
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
|
|
@@ -2023,6 +2029,7 @@ static bool ggml_metal_graph_compute(
|
|
|
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
|
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
|
|
|
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
|
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
|
|
|
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
|
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
|
|
|
|
|
+
|
|
|
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
|
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
|
|
|
|
|
|
|
|
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
|
const int32_t N = src1->ne[is_2D ? 3 : 2];
|
|
@@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute(
|
|
|
|
|
|
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
|
|
|
- switch (src0->type) {
|
|
|
|
|
- case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
|
|
|
|
|
|
|
+ switch (dst->type) {
|
|
|
|
|
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
|
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
|
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
|
|
default: GGML_ASSERT(false);
|
|
default: GGML_ASSERT(false);
|
|
|
};
|
|
};
|