|
|
@@ -392,6 +392,7 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
|
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
|
|
GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
|
|
|
+ GGML_METAL_KERNEL_TYPE_ARGMAX,
|
|
|
|
|
|
GGML_METAL_KERNEL_TYPE_COUNT
|
|
|
};
|
|
|
@@ -956,6 +957,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
|
|
|
}
|
|
|
@@ -1086,6 +1088,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
return has_simdgroup_reduction;
|
|
|
case GGML_OP_RMS_NORM:
|
|
|
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
|
|
|
+ case GGML_OP_ARGMAX:
|
|
|
case GGML_OP_NORM:
|
|
|
case GGML_OP_ROPE:
|
|
|
return true;
|
|
|
@@ -3845,6 +3848,31 @@ static void ggml_metal_encode_node(
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
|
|
} break;
|
|
|
+ case GGML_OP_ARGMAX:
|
|
|
+ {
|
|
|
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
|
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
|
|
|
+ GGML_ASSERT(nb00 == ggml_type_size(src0->type));
|
|
|
+
|
|
|
+ const int64_t nrows = ggml_nrows(src0);
|
|
|
+
|
|
|
+ int nth = 32; // SIMD width
|
|
|
+ while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
|
|
|
+ nth *= 2;
|
|
|
+ }
|
|
|
+
|
|
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline];
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
|
|
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
|
|
|
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
|
|
+ [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
+ } break;
|
|
|
default:
|
|
|
{
|
|
|
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|