|
|
@@ -241,6 +241,8 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
|
|
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
|
|
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
|
|
+ GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
|
|
+ GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_PAD_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_ARANGE_F32,
|
|
|
@@ -272,6 +274,8 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_SIN,
|
|
|
GGML_METAL_KERNEL_TYPE_COS,
|
|
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
|
+ GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
|
+ GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_COUNT
|
|
|
};
|
|
|
@@ -685,6 +689,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true);
|
|
|
@@ -716,6 +722,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
|
|
}
|
|
|
|
|
|
[metal_library release];
|
|
|
@@ -844,8 +852,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
case GGML_OP_IM2COL:
|
|
|
return op->src[0]->type == GGML_TYPE_F16;
|
|
|
case GGML_OP_POOL_1D:
|
|
|
- case GGML_OP_POOL_2D:
|
|
|
return false;
|
|
|
+ case GGML_OP_POOL_2D:
|
|
|
case GGML_OP_UPSCALE:
|
|
|
case GGML_OP_PAD:
|
|
|
case GGML_OP_ARANGE:
|
|
|
@@ -2545,6 +2553,8 @@ static void ggml_metal_encode_node(
|
|
|
} break;
|
|
|
case GGML_OP_IM2COL:
|
|
|
{
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src1));
|
|
|
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
|
|
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
|
|
|
@@ -2574,30 +2584,54 @@ static void ggml_metal_encode_node(
|
|
|
const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
|
|
|
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
|
|
|
|
|
- id<MTLComputePipelineState> pipeline = nil;
|
|
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
|
|
+
|
|
|
+ const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
|
|
|
|
|
|
switch (dst->type) {
|
|
|
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
|
|
|
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
|
|
|
+ case GGML_TYPE_F32: {
|
|
|
+ pipeline = (is_gt_mttpt ?
|
|
|
+ ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
|
|
|
+ :
|
|
|
+ ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
|
|
|
+ } break;
|
|
|
+ case GGML_TYPE_F16: {
|
|
|
+ pipeline = (is_gt_mttpt ?
|
|
|
+ ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
|
|
|
+ :
|
|
|
+ ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
|
|
|
+ } break;
|
|
|
default: GGML_ABORT("fatal error");
|
|
|
};
|
|
|
|
|
|
[encoder setComputePipelineState:pipeline];
|
|
|
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
|
|
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
- [encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
|
|
- [encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
|
|
- [encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
|
|
- [encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
|
|
- [encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
|
|
- [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
|
|
- [encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
|
|
- [encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
|
|
- [encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
|
|
- [encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
|
|
- [encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
|
|
-
|
|
|
- [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
|
|
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
+ [encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
|
|
|
+ [encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
|
|
|
+ [encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
|
|
|
+ [encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
|
|
|
+ [encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
|
|
|
+ [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
|
|
|
+ [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
|
|
|
+ [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
|
|
|
+ [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
|
|
|
+ [encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
|
|
|
+ [encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
|
|
|
+
|
|
|
+ if (is_gt_mttpt) {
|
|
|
+ [encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
|
|
+ [encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
|
|
+ [encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
|
|
+
|
|
|
+ const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
|
|
+
|
|
|
+ const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
|
|
+ } else {
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
|
|
+ }
|
|
|
} break;
|
|
|
case GGML_OP_UPSCALE:
|
|
|
{
|
|
|
@@ -3001,6 +3035,64 @@ static void ggml_metal_encode_node(
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
} break;
|
|
|
+ case GGML_OP_POOL_2D:
|
|
|
+ {
|
|
|
+ GGML_ASSERT(ggml_is_contiguous(src0));
|
|
|
+ GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
|
|
|
+
|
|
|
+ const int32_t * opts = dst->op_params;
|
|
|
+ enum ggml_op_pool op = opts[0];
|
|
|
+
|
|
|
+ id<MTLComputePipelineState> pipeline = nil;
|
|
|
+ switch (src0t) {
|
|
|
+ case GGML_TYPE_F32: {
|
|
|
+ switch(op) {
|
|
|
+ case GGML_OP_POOL_AVG:
|
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
|
|
|
+ case GGML_OP_POOL_MAX:
|
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
|
|
|
+ default: GGML_ASSERT(false && "not implemented");
|
|
|
+ }
|
|
|
+ } break;
|
|
|
+ default: GGML_ASSERT(false && "not implemented");
|
|
|
+ }
|
|
|
+
|
|
|
+ const int32_t k0 = opts[1];
|
|
|
+ const int32_t k1 = opts[2];
|
|
|
+ const int32_t s0 = opts[3];
|
|
|
+ const int32_t s1 = opts[4];
|
|
|
+ const int32_t p0 = opts[5];
|
|
|
+ const int32_t p1 = opts[6];
|
|
|
+
|
|
|
+ const int64_t IH = src0->ne[1];
|
|
|
+ const int64_t IW = src0->ne[0];
|
|
|
+
|
|
|
+ const int64_t N = dst->ne[3];
|
|
|
+ const int64_t OC = dst->ne[2];
|
|
|
+ const int64_t OH = dst->ne[1];
|
|
|
+ const int64_t OW = dst->ne[0];
|
|
|
+
|
|
|
+ const int64_t parallel_elements = N * OC * OH * OW;
|
|
|
+ const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
|
|
|
+ const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline];
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
+ [encoder setBytes:&k0 length:sizeof(int32_t) atIndex:2];
|
|
|
+ [encoder setBytes:&k1 length:sizeof(int32_t) atIndex:3];
|
|
|
+ [encoder setBytes:&s0 length:sizeof(int32_t) atIndex:4];
|
|
|
+ [encoder setBytes:&s1 length:sizeof(int32_t) atIndex:5];
|
|
|
+ [encoder setBytes:&p0 length:sizeof(int32_t) atIndex:6];
|
|
|
+ [encoder setBytes:&p1 length:sizeof(int32_t) atIndex:7];
|
|
|
+ [encoder setBytes:&IH length:sizeof(int64_t) atIndex:8];
|
|
|
+ [encoder setBytes:&IW length:sizeof(int64_t) atIndex:9];
|
|
|
+ [encoder setBytes:&OH length:sizeof(int64_t) atIndex:10];
|
|
|
+ [encoder setBytes:&OW length:sizeof(int64_t) atIndex:11];
|
|
|
+ [encoder setBytes:¶llel_elements length:sizeof(int64_t) atIndex:12];
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
|
|
+ } break;
|
|
|
default:
|
|
|
{
|
|
|
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|