|
|
@@ -418,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
{
|
|
|
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
|
|
|
} break;
|
|
|
+ case GGML_OP_OPT_STEP_SGD:
|
|
|
+ {
|
|
|
+ n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
|
|
|
+ } break;
|
|
|
default:
|
|
|
{
|
|
|
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
|
|
|
@@ -3469,3 +3473,37 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
|
|
|
|
|
|
return 1;
|
|
|
}
|
|
|
+
|
|
|
+int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
|
|
|
+ ggml_tensor * op = ctx->node(idx);
|
|
|
+
|
|
|
+ ggml_metal_library_t lib = ctx->lib;
|
|
|
+ ggml_metal_encoder_t enc = ctx->enc;
|
|
|
+
|
|
|
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
|
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
|
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
|
+ GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
|
|
+
|
|
|
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
|
|
|
+
|
|
|
+ const int64_t np = ggml_nelements(op->src[0]);
|
|
|
+ ggml_metal_kargs_opt_step_sgd args = {
|
|
|
+ /*.np =*/ np,
|
|
|
+ };
|
|
|
+
|
|
|
+ int ida = 0;
|
|
|
+
|
|
|
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
|
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
|
|
|
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
|
|
|
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
|
|
|
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
|
|
|
+
|
|
|
+ const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
|
|
|
+ const int64_t n = (np + nth - 1) / nth;
|
|
|
+
|
|
|
+ ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
|
|
|
+
|
|
|
+ return 1;
|
|
|
+}
|