@@ -33,6 +33,7 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/roll.cuh"
 #include "ggml-cuda/scale.cuh"
+#include "ggml-cuda/softcap.cuh"
 #include "ggml-cuda/softmax.cuh"
 #include "ggml-cuda/ssm-conv.cuh"
 #include "ggml-cuda/ssm-scan.cuh"
@@ -2770,7 +2771,12 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
-static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
+#ifndef NDEBUG
+    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
+    GGML_ASSERT(unary_ops.size() == num_unary);
+#endif
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
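The #ifndef NDEBUG block above enforces the new calling convention in debug builds: unary_ops must carry exactly one entry per GGML_OP_UNARY in ops. For illustration, both call shapes used by this patch (they also appear in the dispatch hunk further down):

    ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {});                                   // no unary ops
    ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH }); // one unary op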
@@ -2798,9 +2804,32 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
             return false;
         }
+
+        return true;
     }
 
-    return true;
+    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
+        && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
+        const ggml_tensor *scale  = cgraph->nodes[node_idx];
+        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
+        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];
+
+        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(scale->type == GGML_TYPE_F32);
+
+        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
+            return false;
+        }
+
+        // Check for bias
+        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
+            return false;
+        }
+
+        return true;
+    }
+
+    return false;
 }
 
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
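The three-op pattern matched above is the soft-capping construct y = s2 * tanh(s1 * x), used for example as logit capping in Gemma-2-style models (cap * tanh(x / cap)). A minimal sketch of how such a subgraph is typically built with the public ggml API; the helper name and the cap parameter are illustrative, not part of this patch:

    // builds the SCALE -> UNARY(TANH) -> SCALE chain that the fusion above matches
    static ggml_tensor * build_softcap(ggml_context * ctx, ggml_tensor * x, float cap) {
        ggml_tensor * t = ggml_scale(ctx, x, 1.0f/cap); // GGML_OP_SCALE, bias param stays 0.0f
        t = ggml_tanh(ctx, t);                          // GGML_OP_UNARY with GGML_UNARY_OP_TANH
        return ggml_scale(ctx, t, cap);                 // GGML_OP_SCALE, bias param stays 0.0f
    }

The bias check rejects chains built with ggml_scale_bias(), where op_params[1] is nonzero and a single fused tanh-and-multiply would no longer be equivalent.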
@@ -2821,10 +2850,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             }
 
             static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
-            if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
-                ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
-                i++;
-                continue;
+            if (!disable_fusion) {
+                if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+                    ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
+                    i++;
+                    continue;
+                }
+
+                if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
+                    i += 2;
+                    ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
+                    continue;
+                }
             }
 #ifndef NDEBUG
             assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
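For context, this is roughly the elementwise kernel a fused softcap op can launch; the real implementation lives in ggml-cuda/softcap.cu, which is not part of this excerpt, so the kernel below is a sketch under that assumption rather than the actual code:

    // computes dst[i] = s_out * tanhf(s_in * x[i]) in a single pass,
    // replacing three memory-bound ops (scale, tanh, scale) with one
    static __global__ void softcap_f32(const float * x, float * dst, const float s_in, const float s_out, const int k) {
        const int i = blockDim.x*blockIdx.x + threadIdx.x;
        if (i >= k) {
            return;
        }
        dst[i] = s_out * tanhf(s_in * x[i]);
    }

Since every op in the chain is bandwidth-bound, fusing cuts the chain's global memory traffic to roughly a third: each element is read and written once instead of three times.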