@@ -466,6 +466,14 @@ static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
     { 2, 0, 1 }, // set_rows->src[0] == view
 };
 
+static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_view_set_rows_edges {
+    { 1, 0, 0 }, // mul->src[0] == rms
+    { 2, 0, 1 }, // rope->src[0] == mul
+    { 3, 0, 2 }, // view->src[0] == rope
+    { 4, 0, 3 }, // set_rows->src[0] == view
+};
+
+
 struct vk_device_struct {
     std::recursive_mutex mutex;
 
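Each entry in these edge lists reads { node, src_index, other_node }, with node offsets relative to the first node of the candidate fusion. A minimal sketch of the invariant ggml_check_edges is assumed to enforce (check_edges_sketch is illustrative, not part of the patch):

    // For a fusion starting at node i, each { n, s, o } entry requires
    // cgraph->nodes[i + n]->src[s] == cgraph->nodes[i + o],
    // e.g. { 2, 0, 1 } above: the rope node's src[0] must be the mul node.
    static bool check_edges_sketch(const ggml_cgraph * g, int i,
                                   std::initializer_list<std::array<int, 3>> edges) {
        for (const auto & e : edges) {
            if (g->nodes[i + e[0]]->src[e[1]] != g->nodes[i + e[2]]) {
                return false;
            }
        }
        return true;
    }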
@@ -617,6 +625,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_rms_norm_mul_f32;
     vk_pipeline pipeline_rms_norm_partials_f32;
     vk_pipeline pipeline_rms_norm_mul_partials_f32;
+    vk_pipeline pipeline_rms_norm_mul_rope_f32_f32;
+    vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
 
@@ -1060,6 +1070,7 @@ struct vk_op_diag_mask_push_constants {
 };
 
 struct vk_op_rope_push_constants {
+    uint32_t rope_mode;
     uint32_t ncols;
     uint32_t n_dims;
     float freq_scale;
@@ -1079,6 +1090,12 @@ struct vk_op_rope_push_constants {
     uint32_t set_rows_stride;
 };
 
+// For fused rms_norm+mul+rope(+view+set_rows)
+struct vk_op_rms_norm_mul_rope_push_constants {
+    vk_op_binary_push_constants bin;
+    vk_op_rope_push_constants rope;
+};
+
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
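The fused push-constant block is simply the existing binary constants followed by the existing rope constants, so the shader can reuse both layouts unchanged. A hypothetical sanity check of that packing assumption (not part of the patch; both sub-structs are 4-byte-aligned scalars, so no padding is expected):

    #include <cstddef> // offsetof
    static_assert(offsetof(vk_op_rms_norm_mul_rope_push_constants, rope) ==
                  sizeof(vk_op_binary_push_constants),
                  "rope constants must start right after the binary constants");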
@@ -3557,6 +3574,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
 
+    if (device->float_controls_rte_fp16 &&
+        sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
+        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
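The maxPushConstantsSize guard is load-bearing: Vulkan only guarantees 128 bytes of push constants, and the combined struct exceeds that. Rough size budget (field counts read off the struct definitions above; estimates, not measured):

    // sizeof(vk_op_binary_push_constants) ~= 29 fields * 4 B = 116 B
    // sizeof(vk_op_rope_push_constants)   ~= 22 fields * 4 B =  88 B
    // combined                            ~= 204 B, above the 128 B spec minimum
    // -> the fused pipelines are only compiled on devices that advertise room,
    //    and float_controls_rte_fp16 gates the f16 variant, which uses the
    //    round-to-nearest-even shader build.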
@@ -9590,21 +9613,149 @@ static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const g
     return num_bytes;
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params) {
+static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) {
+    const int n_dims = ((const int32_t *) dst->op_params)[1];
+    const int mode = ((const int32_t *) dst->op_params)[2];
+    // const int n_ctx = ((const int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
+    const float freq_base = ((const float *) dst->op_params)[5];
+    const float freq_scale = ((const float *) dst->op_params)[6];
+    const float ext_factor = ((const float *) dst->op_params)[7];
+    const float attn_factor = ((const float *) dst->op_params)[8];
+    const float beta_fast = ((const float *) dst->op_params)[9];
+    const float beta_slow = ((const float *) dst->op_params)[10];
+    int sections[4] {};
+    if (mode & GGML_ROPE_TYPE_MROPE) {
+        memcpy(sections, (const int32_t *) dst->op_params + 11, sizeof(int)*4);
+    }
+
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);
+
+    uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
+    uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
+
+    vk_op_rope_push_constants rope {
+        (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
+        has_ff, (uint32_t)src0->ne[2], nb01, nb02,
+        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
+    };
+
+    return rope;
+}
+
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) {
+    ggml_tensor * dst;
+    const ggml_tensor * src0;
+    const ggml_tensor * src1;
+
+    if (ctx->num_additional_fused_ops > 0) {
+        // fused rms_norm + mul
+        ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+        ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0];
+        dst = mul;
+        src0 = cgraph->nodes[node_idx]->src[0];
+        src1 = other_src;
+    } else {
+        dst = cgraph->nodes[node_idx];
+        src0 = src1 = dst->src[0];
+    }
+
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
     uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
 
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+    vk_op_binary_push_constants bin {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
         op_params[0], 0.0f, (int32_t)param3,
-    });
+    };
+
+    // more than one fused op means rms_norm+mul+rope
+    if (ctx->num_additional_fused_ops > 1) {
+        static constexpr uint32_t max_tensors = 7;
+        const ggml_tensor *tensors[max_tensors] {};
+
+        ggml_tensor *rms = cgraph->nodes[node_idx + 0];
+        ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+        ggml_tensor *rope = cgraph->nodes[node_idx + 2];
+
+        ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
+
+        bool do_set_rows = ctx->num_additional_fused_ops == 4;
+
+        tensors[0] = rms->src[0];
+        tensors[1] = other_src;
+        tensors[2] = mul;
+        tensors[3] = rope->src[1]; // pos
+        tensors[4] = rope->src[2]; // ff
+        tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst
+        tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr;
+        const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0;
+
+        vk_op_rms_norm_mul_rope_push_constants pc;
+        pc.bin = bin;
+        pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride);
+
+        vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+
+        ggml_backend_vk_buffer_context * buf_ctx[max_tensors];
+        vk_buffer buf[max_tensors];
+        size_t offset[max_tensors];
+        bool uma[max_tensors];
+
+        for (uint32_t i = 0; i < max_tensors; ++i) {
+            if (!tensors[i]) {
+                // If any remaining descriptors are unused, just point them at src[0]
+                buf[i] = buf[0];
+                offset[i] = 0;
+                continue;
+            }
+            buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
+            buf[i] = nullptr;
+            offset[i] = 0;
+            uma[i] = false;
+
+            if (ctx->device->uma) {
+                ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
+                uma[i] = buf[i] != nullptr;
+            }
+            if (!uma[i]) {
+                buf[i] = buf_ctx[i]->dev_buffer;
+                offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
+            }
+            GGML_ASSERT(buf[i] != nullptr);
+        }
+
+        std::array<uint32_t, 3> elements;
+        elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
+
+        static_assert(max_tensors == 7);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {
+                ggml_vk_subbuffer(ctx, buf[0], offset[0]),
+                ggml_vk_subbuffer(ctx, buf[1], offset[1]),
+                ggml_vk_subbuffer(ctx, buf[2], offset[2]),
+                ggml_vk_subbuffer(ctx, buf[3], offset[3]),
+                ggml_vk_subbuffer(ctx, buf[4], offset[4]),
+                ggml_vk_subbuffer(ctx, buf[5], offset[5]),
+                ggml_vk_subbuffer(ctx, buf[6], offset[6]),
+            }, pc, elements);
+    } else {
+        ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin));
+    }
 
     if (ctx->do_add_rms_partials_offset_calculation) {
         ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
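For orientation, the subgraph this path fuses is the usual attention-input pattern in llama.cpp-style graphs. An illustrative sketch (the exact ggml calls and names here are assumptions, not taken from the patch):

    // RMS_NORM -> MUL -> ROPE [-> VIEW -> SET_ROWS]
    //     t = ggml_rms_norm(ctx, x, eps);
    //     t = ggml_mul(ctx, t, norm_weights);             // mul->src[0]  == rms
    //     t = ggml_rope_ext(ctx, t, pos, freq_factors, /* ... */); // rope->src[0] == mul
    //     // optionally, a VIEW of the rope result feeds SET_ROWS, which
    //     // scatters rows into the KV cache via the index tensor (set_rows->src[1])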
@@ -9758,9 +9909,6 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     // const int n_ctx = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
     const float freq_base = ((float *) dst->op_params)[5];
-    const float freq_scale = ((float *) dst->op_params)[6];
-    const float ext_factor = ((float *) dst->op_params)[7];
-    const float attn_factor = ((float *) dst->op_params)[8];
     const float beta_fast = ((float *) dst->op_params)[9];
     const float beta_slow = ((float *) dst->op_params)[10];
     int sections[4] {};
@@ -9768,16 +9916,9 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
     }
 
-    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
-
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type);
-    uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type);
-
     uint32_t set_rows_stride = 0;
     // Fused rope + view + set_rows passes the set_rows destination stride in set_rows_stride
     // and overrides the dst and sets src3=row_indices
@@ -9787,12 +9928,8 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
         dst = cgraph->nodes[node_idx + 2];
     }
 
-    ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE, {
-        (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
-        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
-        src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
-        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
-    });
+    ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE,
+        ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
 }
 
 static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
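Centralizing the constant construction in ggml_vk_make_rope_constants keeps the fused and standalone rope paths from drifting apart; rope_mode now travels in the push constants so one shader can branch between norm and neox layouts at runtime. As a worked example of theta_scale (values are common defaults, not from the patch):

    // theta_scale = powf(freq_base, -2.0f / n_dims)
    // with freq_base = 10000 and n_dims = 128:
    //     theta_scale = 10000^(-2/128) ~= 0.866
    // and the angle for the i-th rotated pair is pos * theta_scale^i.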
@@ -11307,6 +11444,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         if (n->op == GGML_OP_GLU) {
             std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
         }
+        if (n->op == GGML_OP_ROPE) {
+            const int mode = ((const int32_t *) n->op_params)[2];
+            std::cerr << " rope mode: " << mode;
+        }
         std::cerr << std::endl;
     }
 #endif
@@ -11414,14 +11555,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_RMS_NORM:
-        if (ctx->num_additional_fused_ops > 0) {
-            // fused rms_norm + mul
-            ggml_tensor *mul = cgraph->nodes[node_idx + 1];
-            ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params);
-        } else {
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params);
-        }
+        ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);
         break;
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);
@@ -12407,6 +12541,70 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
     return true;
 }
 
+// Check whether the tensors overlap in memory but are not equal.
+// Fusions can potentially overwrite src tensors in ways that are not prevented
+// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
+// to overlap if they are exactly equal.
+// XXX TODO this check is probably missing from several fusion optimizations.
+static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
+    ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
+    vk_buffer a_buf = a_buf_ctx->dev_buffer;
+    ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
+    vk_buffer b_buf = b_buf_ctx->dev_buffer;
+    if (a_buf == b_buf) {
+        auto a_base = vk_tensor_offset(a) + a->view_offs;
+        auto a_size = ggml_nbytes(a);
+        auto b_base = vk_tensor_offset(b) + b->view_offs;
+        auto b_size = ggml_nbytes(b);
+
+        if (a_base == b_base && a_size == b_size) {
+            return false;
+        }
+
+        if ((b_base <= a_base && a_base < b_base + b_size) ||
+            (a_base <= b_base && b_base < a_base + a_size)) {
+            return true;
+        }
+    }
+    return false;
+}
+
+static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
+                                               int node_idx) {
+    GGML_UNUSED(ctx);
+    const ggml_tensor *rms = cgraph->nodes[node_idx + 0];
+    const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+    const ggml_tensor *rope = cgraph->nodes[node_idx + 2];
+
+    const int mode = ((const int32_t *) rope->op_params)[2];
+
+    // noncontig tensors aren't tested, and don't seem common in practice
+    if (!ggml_is_contiguous(rms) ||
+        !ggml_is_contiguous(mul) ||
+        !ggml_is_contiguous(rope)) {
+        return false;
+    }
+
+    // only norm/neox are handled in the shader
+    if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) {
+        return false;
+    }
+
+    // shared memory size for passing data from mul->rope
+    if (mul->ne[0] > 1024) {
+        return false;
+    }
+
+    // must not overwrite srcs in a way that's not elementwise
+    ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
+    if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
+        ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
+        return false;
+    }
+
+    return true;
+}
+
 static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
 
     const ggml_tensor *first_node = cgraph->nodes[node_idx];
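In interval terms, ggml_vk_tensors_overlap_but_not_equal treats each tensor as a byte range inside its vk_buffer. A worked example with hypothetical offsets:

    // same buffer, a = [0, 1024), b = [0, 1024)   -> identical ranges -> false
    //     (a purely elementwise fusion may overwrite in place)
    // same buffer, a = [0, 1024), b = [512, 1536) -> partial alias    -> true
    //     (the fused write would clobber bytes a still needs)
    // different buffers                           -> false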
@@ -12552,12 +12750,20 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
         if (num_adds) {
             ctx->num_additional_fused_ops = num_adds - 1;
-        } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-            ctx->num_additional_fused_ops = 1;
         } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
             ctx->num_additional_fused_ops = 1;
         } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
             ctx->num_additional_fused_ops = 1;
+        } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
+                   ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
+                   ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
+                   ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
+            ctx->num_additional_fused_ops = 4;
+        } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE }) &&
+                   ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
+            ctx->num_additional_fused_ops = 2;
+        } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
         } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
                    ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
                    ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
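num_additional_fused_ops counts the nodes consumed after the first one, which is why the branches above must stay ordered longest-first (otherwise the plain RMS_NORM+MUL case would shadow its longer forms):

    // RMS_NORM+MUL                     -> num_additional_fused_ops = 1
    // RMS_NORM+MUL+ROPE                -> num_additional_fused_ops = 2
    // RMS_NORM+MUL+ROPE+VIEW+SET_ROWS  -> num_additional_fused_ops = 4
    // ROPE+VIEW+SET_ROWS               -> num_additional_fused_ops = 2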
@@ -12790,14 +12996,34 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             }
             if (ok) {
                 current_set.push_back(j);
+
+                int rope_idx = j;
+
+                // When we've found RMS_NORM + MUL, try to find a ROPE that uses it
+                if (j > 0 &&
+                    graph->nodes[j]->op == GGML_OP_MUL &&
+                    graph->nodes[j-1]->op == GGML_OP_RMS_NORM) {
+                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
+                        if (graph->nodes[k]->op == GGML_OP_ROPE &&
+                            graph->nodes[k]->src[0] == graph->nodes[j] &&
+                            // Check that other srcs are already valid
+                            graph->nodes[k]->src[1]->op == GGML_OP_NONE &&
+                            (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) {
+                            rope_idx = k;
+                            current_set.push_back(rope_idx);
+                            used[rope_idx] = true;
+                            break;
+                        }
+                    }
+                }
                 // Look for ROPE + VIEW + SET_ROWS and make them consecutive
-                if (graph->nodes[j]->op == GGML_OP_ROPE) {
+                if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) {
                     int view_idx = -1;
                     int set_rows_idx = -1;
-                    for (int k = j+1; k < std::min(j + 10, graph->n_nodes); ++k) {
+                    for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) {
                         if (view_idx == -1 &&
                             graph->nodes[k]->op == GGML_OP_VIEW &&
-                            graph->nodes[k]->src[0] == graph->nodes[j]) {
+                            graph->nodes[k]->src[0] == graph->nodes[rope_idx]) {
                             view_idx = k;
                             continue;
                         }
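The new scan hoists a matching ROPE (and, via the existing logic below it, the ROPE's VIEW/SET_ROWS tail) directly behind the RMS_NORM+MUL pair, so graph_compute later sees the fusion candidates as consecutive nodes. A hypothetical before/after node order:

    // before: RMS_NORM, MUL, <other>, <other>, ROPE(src0 == MUL), VIEW, SET_ROWS
    // after : RMS_NORM, MUL, ROPE, VIEW, SET_ROWS, <other>, <other>
    // The ROPE is only moved when its pos/freq-factor srcs are graph inputs
    // (GGML_OP_NONE), so hoisting it earlier cannot violate a dependency.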