|
|
@@ -310,6 +310,7 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
|
|
+ GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
|
|
|
@@ -877,6 +878,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
|
|
@@ -1099,6 +1101,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
case GGML_OP_POOL_2D:
|
|
|
case GGML_OP_UPSCALE:
|
|
|
case GGML_OP_PAD:
|
|
|
+ case GGML_OP_PAD_REFLECT_1D:
|
|
|
case GGML_OP_ARANGE:
|
|
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
|
case GGML_OP_ARGSORT:
|
|
|
@@ -3258,6 +3261,38 @@ static void ggml_metal_encode_node(
|
|
|
|
|
|
const int nth = MIN(1024, ne0);
|
|
|
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
+ } break;
|
|
|
+ case GGML_OP_PAD_REFLECT_1D:
|
|
|
+ {
|
|
|
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
+
|
|
|
+ const int32_t p0 = ((const int32_t *)(dst->op_params))[0];
|
|
|
+ const int32_t p1 = ((const int32_t *)(dst->op_params))[1];
|
|
|
+
|
|
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline];
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
|
|
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
|
|
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
|
|
|
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
|
|
|
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6];
|
|
|
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
|
|
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
|
|
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
|
|
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
|
|
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11];
|
|
|
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12];
|
|
|
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13];
|
|
|
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14];
|
|
|
+ [encoder setBytes:&p0 length:sizeof(p0) atIndex:15];
|
|
|
+ [encoder setBytes:&p1 length:sizeof(p1) atIndex:16];
|
|
|
+
|
|
|
+ const int nth = MIN(1024, ne0);
|
|
|
+
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
} break;
|
|
|
case GGML_OP_ARANGE:
|