|
|
@@ -306,6 +306,8 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
|
|
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
|
|
|
+ GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32,
|
|
|
+ GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
|
|
@@ -870,6 +872,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
|
|
@@ -1069,6 +1073,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
case GGML_OP_REPEAT:
|
|
|
case GGML_OP_SCALE:
|
|
|
case GGML_OP_CLAMP:
|
|
|
+ case GGML_OP_CONV_TRANSPOSE_1D:
|
|
|
return true;
|
|
|
case GGML_OP_SQR:
|
|
|
case GGML_OP_SQRT:
|
|
|
@@ -3138,6 +3143,49 @@ static void ggml_metal_encode_node(
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
|
|
}
|
|
|
} break;
|
|
|
+ case GGML_OP_CONV_TRANSPOSE_1D:
|
|
|
+ {
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src1));
|
|
|
+ GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32);
|
|
|
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
|
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
+
|
|
|
+ const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
|
|
|
+
|
|
|
+ const int32_t IC = src1->ne[1];
|
|
|
+ const int32_t IL = src1->ne[0];
|
|
|
+
|
|
|
+ const int32_t K = src0->ne[0];
|
|
|
+
|
|
|
+ const int32_t OL = dst->ne[0];
|
|
|
+ const int32_t OC = dst->ne[1];
|
|
|
+
|
|
|
+ id<MTLComputePipelineState> pipeline;
|
|
|
+
|
|
|
+ switch (src0->type) {
|
|
|
+ case GGML_TYPE_F32: {
|
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline;
|
|
|
+ } break;
|
|
|
+ case GGML_TYPE_F16: {
|
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline;
|
|
|
+ } break;
|
|
|
+ default: GGML_ABORT("fatal error");
|
|
|
+ };
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline];
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
|
|
+ [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3];
|
|
|
+ [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4];
|
|
|
+ [encoder setBytes:&K length:sizeof( int32_t) atIndex:5];
|
|
|
+ [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6];
|
|
|
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7];
|
|
|
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8];
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
+ } break;
|
|
|
case GGML_OP_UPSCALE:
|
|
|
{
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|