|
@@ -266,10 +266,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
{
|
|
{
|
|
|
n_fuse = ggml_metal_op_set_rows(ctx, idx);
|
|
n_fuse = ggml_metal_op_set_rows(ctx, idx);
|
|
|
} break;
|
|
} break;
|
|
|
- case GGML_OP_RMS_NORM:
|
|
|
|
|
- {
|
|
|
|
|
- n_fuse = ggml_metal_op_rms_norm(ctx, idx);
|
|
|
|
|
- } break;
|
|
|
|
|
case GGML_OP_L2_NORM:
|
|
case GGML_OP_L2_NORM:
|
|
|
{
|
|
{
|
|
|
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
|
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
|
|
@@ -279,6 +275,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
|
|
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
|
n_fuse = ggml_metal_op_group_norm(ctx, idx);
|
|
|
} break;
|
|
} break;
|
|
|
case GGML_OP_NORM:
|
|
case GGML_OP_NORM:
|
|
|
|
|
+ case GGML_OP_RMS_NORM:
|
|
|
{
|
|
{
|
|
|
n_fuse = ggml_metal_op_norm(ctx, idx);
|
|
n_fuse = ggml_metal_op_norm(ctx, idx);
|
|
|
} break;
|
|
} break;
|
|
@@ -2346,146 +2343,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
|
|
return n_fuse;
|
|
return n_fuse;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-int ggml_metal_op_rms_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
|
|
- ggml_cgraph * gf = ctx->gf;
|
|
|
|
|
- ggml_tensor * op = ggml_graph_node(gf, idx);
|
|
|
|
|
-
|
|
|
|
|
- ggml_metal_library_t lib = ctx->lib;
|
|
|
|
|
- ggml_metal_encoder_t enc = ctx->enc;
|
|
|
|
|
-
|
|
|
|
|
- const int idx_end = ctx->idx_end;
|
|
|
|
|
-
|
|
|
|
|
- const bool use_fusion = ctx->use_fusion;
|
|
|
|
|
-
|
|
|
|
|
- const int debug_fusion = ctx->debug_fusion;
|
|
|
|
|
-
|
|
|
|
|
- ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
|
|
|
|
|
-
|
|
|
|
|
- GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
|
|
|
- GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
|
|
|
- GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
|
|
|
- GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
|
|
|
|
|
-
|
|
|
|
|
- float eps;
|
|
|
|
|
- memcpy(&eps, op->op_params, sizeof(float));
|
|
|
|
|
-
|
|
|
|
|
- ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
|
|
|
- ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
|
|
|
-
|
|
|
|
|
- ggml_metal_kargs_rms_norm args = {
|
|
|
|
|
- /*.ne00 =*/ ne00,
|
|
|
|
|
- /*.ne00_4 =*/ ne00/4,
|
|
|
|
|
- /*.nb1 =*/ nb1,
|
|
|
|
|
- /*.nb2 =*/ nb2,
|
|
|
|
|
- /*.nb3 =*/ nb3,
|
|
|
|
|
- /*.eps =*/ eps,
|
|
|
|
|
- /*.nef1 =*/ { ne01 },
|
|
|
|
|
- /*.nef2 =*/ { ne02 },
|
|
|
|
|
- /*.nef3 =*/ { ne03 },
|
|
|
|
|
- /*.nbf1 =*/ { nb01 },
|
|
|
|
|
- /*.nbf2 =*/ { nb02 },
|
|
|
|
|
- /*.nbf3 =*/ { nb03 },
|
|
|
|
|
- };
|
|
|
|
|
-
|
|
|
|
|
- ggml_op fops[8];
|
|
|
|
|
-
|
|
|
|
|
- int n_fuse = 1;
|
|
|
|
|
-
|
|
|
|
|
- ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
|
|
|
|
|
-
|
|
|
|
|
- // d[0] = rms_norm(a)
|
|
|
|
|
- // d[1] = mul(d[0], b)
|
|
|
|
|
- // d[2] = add(d[1], c)
|
|
|
|
|
- if (use_fusion) {
|
|
|
|
|
- fops[0] = GGML_OP_RMS_NORM;
|
|
|
|
|
- fops[1] = GGML_OP_MUL;
|
|
|
|
|
- fops[2] = GGML_OP_ADD;
|
|
|
|
|
-
|
|
|
|
|
- for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
|
|
|
|
- if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
|
|
|
|
|
-
|
|
|
|
|
- bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
|
|
|
|
|
-
|
|
|
|
|
- args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
|
|
|
|
|
- args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
|
|
|
|
|
- args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
|
|
|
|
|
-
|
|
|
|
|
- args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
|
|
|
|
|
- args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
|
|
|
|
|
- args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- ++n_fuse;
|
|
|
|
|
-
|
|
|
|
|
- if (debug_fusion > 1 && n_fuse > 1) {
|
|
|
|
|
- if (n_fuse == 2) {
|
|
|
|
|
- GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL\n", __func__);
|
|
|
|
|
- }
|
|
|
|
|
- if (n_fuse == 3) {
|
|
|
|
|
- GGML_LOG_DEBUG("%s: fuse: RMS_NORM + MUL + ADD\n", __func__);
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if (n_fuse > 1) {
|
|
|
|
|
- bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
|
|
|
|
|
-
|
|
|
|
|
- for (int i = 1; i < n_fuse; ++i) {
|
|
|
|
|
- if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
|
|
|
|
|
- ggml_metal_op_concurrency_reset(ctx);
|
|
|
|
|
-
|
|
|
|
|
- break;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rms_norm(lib, op, n_fuse);
|
|
|
|
|
-
|
|
|
|
|
- int nth = 32; // SIMD width
|
|
|
|
|
-
|
|
|
|
|
- while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
|
|
|
- nth *= 2;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
|
|
|
- nth = std::min(nth, ne00/4);
|
|
|
|
|
-
|
|
|
|
|
- const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
|
|
|
|
-
|
|
|
|
|
- ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
|
|
|
- ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
|
|
|
- ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
|
|
|
- ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
|
|
|
|
|
- ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
|
|
|
|
|
- ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
|
|
|
|
|
-
|
|
|
|
|
- ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
|
|
|
-
|
|
|
|
|
- ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
|
|
|
-
|
|
|
|
|
- return n_fuse;
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
|
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
ggml_cgraph * gf = ctx->gf;
|
|
ggml_cgraph * gf = ctx->gf;
|
|
|
ggml_tensor * op = ggml_graph_node(gf, idx);
|
|
ggml_tensor * op = ggml_graph_node(gf, idx);
|
|
@@ -2594,6 +2451,14 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
ggml_metal_library_t lib = ctx->lib;
|
|
ggml_metal_library_t lib = ctx->lib;
|
|
|
ggml_metal_encoder_t enc = ctx->enc;
|
|
ggml_metal_encoder_t enc = ctx->enc;
|
|
|
|
|
|
|
|
|
|
+ const int idx_end = ctx->idx_end;
|
|
|
|
|
+
|
|
|
|
|
+ const bool use_fusion = ctx->use_fusion;
|
|
|
|
|
+
|
|
|
|
|
+ const int debug_fusion = ctx->debug_fusion;
|
|
|
|
|
+
|
|
|
|
|
+ ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
|
|
|
|
|
+
|
|
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
|
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
|
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
|
@@ -2602,37 +2467,121 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
|
|
float eps;
|
|
float eps;
|
|
|
memcpy(&eps, op->op_params, sizeof(float));
|
|
memcpy(&eps, op->op_params, sizeof(float));
|
|
|
|
|
|
|
|
|
|
+ ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
|
|
|
|
+ ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
|
|
|
|
+
|
|
|
ggml_metal_kargs_norm args = {
|
|
ggml_metal_kargs_norm args = {
|
|
|
/*.ne00 =*/ ne00,
|
|
/*.ne00 =*/ ne00,
|
|
|
- /*.ne00_4 =*/ ne00/4,
|
|
|
|
|
- /*.nb01 =*/ nb01,
|
|
|
|
|
|
|
+ /*.ne00_t =*/ ne00 % 4 == 0 ? ne00/4 : ne00,
|
|
|
|
|
+ /*.nb1 =*/ nb1,
|
|
|
|
|
+ /*.nb2 =*/ nb2,
|
|
|
|
|
+ /*.nb3 =*/ nb3,
|
|
|
/*.eps =*/ eps,
|
|
/*.eps =*/ eps,
|
|
|
|
|
+ /*.nef1 =*/ { ne01 },
|
|
|
|
|
+ /*.nef2 =*/ { ne02 },
|
|
|
|
|
+ /*.nef3 =*/ { ne03 },
|
|
|
|
|
+ /*.nbf1 =*/ { nb01 },
|
|
|
|
|
+ /*.nbf2 =*/ { nb02 },
|
|
|
|
|
+ /*.nbf3 =*/ { nb03 },
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op);
|
|
|
|
|
|
|
+ ggml_op fops[8];
|
|
|
|
|
+
|
|
|
|
|
+ int n_fuse = 1;
|
|
|
|
|
+
|
|
|
|
|
+ ggml_metal_buffer_id bid_fuse[2] = { bid_src0, bid_src0 };
|
|
|
|
|
+
|
|
|
|
|
+ // d[0] = norm(a)
|
|
|
|
|
+ // d[1] = mul(d[0], b)
|
|
|
|
|
+ // d[2] = add(d[1], c)
|
|
|
|
|
+ if (use_fusion) {
|
|
|
|
|
+ fops[0] = op->op;
|
|
|
|
|
+ fops[1] = GGML_OP_MUL;
|
|
|
|
|
+ fops[2] = GGML_OP_ADD;
|
|
|
|
|
+
|
|
|
|
|
+ for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
|
|
|
|
+ if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ //ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
|
|
|
|
|
+
|
|
|
|
|
+ bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
|
|
|
|
|
+
|
|
|
|
|
+ args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
|
|
|
|
|
+ args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
|
|
|
|
|
+ args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
|
|
|
|
|
+
|
|
|
|
|
+ args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
|
|
|
|
|
+ args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
|
|
|
|
|
+ args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ ++n_fuse;
|
|
|
|
|
+
|
|
|
|
|
+ if (debug_fusion > 1 && n_fuse > 1) {
|
|
|
|
|
+ if (n_fuse == 2) {
|
|
|
|
|
+ GGML_LOG_DEBUG("%s: fuse: %s + MUL\n", __func__, ggml_op_name(op->op));
|
|
|
|
|
+ }
|
|
|
|
|
+ if (n_fuse == 3) {
|
|
|
|
|
+ GGML_LOG_DEBUG("%s: fuse: %s + MUL + ADD\n", __func__, ggml_op_name(op->op));
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (n_fuse > 1) {
|
|
|
|
|
+ bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
|
|
|
|
|
+
|
|
|
|
|
+ for (int i = 1; i < n_fuse; ++i) {
|
|
|
|
|
+ if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
|
|
|
|
|
+ ggml_metal_op_concurrency_reset(ctx);
|
|
|
|
|
+
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
|
|
|
|
|
|
|
|
int nth = 32; // SIMD width
|
|
int nth = 32; // SIMD width
|
|
|
- while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
|
|
|
|
|
+
|
|
|
|
|
+ while (nth < args.ne00_t && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
|
|
|
nth *= 2;
|
|
nth *= 2;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
|
|
- nth = std::min(nth, ne00/4);
|
|
|
|
|
|
|
+ nth = std::min(nth, args.ne00_t);
|
|
|
|
|
|
|
|
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
|
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
|
|
|
|
|
|
|
|
- const int64_t nrows = ggml_nrows(op->src[0]);
|
|
|
|
|
-
|
|
|
|
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
|
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
|
|
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
|
|
|
|
- ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
|
|
|
|
|
|
+ ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
|
|
|
|
+ ggml_metal_encoder_set_buffer (enc, bid_fuse[0], 2);
|
|
|
|
|
+ ggml_metal_encoder_set_buffer (enc, bid_fuse[1], 3);
|
|
|
|
|
+ ggml_metal_encoder_set_buffer (enc, bid_dst, 4);
|
|
|
|
|
|
|
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
|
|
|
|
|
|
|
- ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
|
|
|
|
|
|
|
+ ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
|
|
|
|
|
|
|
|
- return 1;
|
|
|
|
|
|
|
+ return n_fuse;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
|
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|