|
|
@@ -583,6 +583,7 @@ struct vk_device_struct {
|
|
|
bool disable_fusion;
|
|
|
bool disable_host_visible_vidmem;
|
|
|
bool allow_sysmem_fallback;
|
|
|
+ bool disable_optimize_graph;
|
|
|
|
|
|
#ifdef GGML_VULKAN_MEMORY_DEBUG
|
|
|
std::unique_ptr<vk_memory_logger> memory_logger;
|
|
|
@@ -3592,6 +3593,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
|
|
|
device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
|
|
|
|
|
|
+ const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH");
|
|
|
+ device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr;
|
|
|
+
|
|
|
bool fp16_storage = false;
|
|
|
bool fp16_compute = false;
|
|
|
bool maintenance4_support = false;
|
|
|
@@ -11853,6 +11857,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
UNUSED(backend);
|
|
|
}
|
|
|
|
|
|
+// Sort the graph for improved parallelism: reorders graph->nodes in place so that nodes which do not depend on each other become adjacent. Returns without changes when ctx->device->disable_optimize_graph is set, or when num_small_nodes is below half (integer division) of the counted nodes.
|
|
|
+static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph)
|
|
|
+{
|
|
|
+ VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)");
|
|
|
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
|
|
|
+
|
|
|
+ if (ctx->device->disable_optimize_graph) { // flag is set from the GGML_VK_DISABLE_OPTIMIZE_GRAPH environment variable
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ auto const &is_empty = [](ggml_tensor * node) -> bool { // true for the ops handled as "view" nodes below: NONE, RESHAPE, TRANSPOSE, VIEW, PERMUTE
|
|
|
+ return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
|
|
|
+ };
|
|
|
+
|
|
|
+ auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool { // true if dst lists src in dst->src[], or both resolve (via view_src, else the tensor itself) to the same tensor
|
|
|
+ for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
|
|
|
+ if (dst->src[s] == src) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // Implicit dependency if they view the same tensor (compared through view_src, falling back to the tensor itself).
|
|
|
+ const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;
|
|
|
+ const ggml_tensor *src2 = src->view_src ? src->view_src : src;
|
|
|
+ if (dst2 == src2) {
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ return false;
|
|
|
+ };
|
|
|
+
|
|
|
+ // This function tries to reorder the graph to allow nodes to run in parallel.
|
|
|
+ // This helps with small batches, but for large batches it's a slowdown, probably
|
|
|
+ // due to cache contention. So only reorder if the majority of nodes have few rows.
|
|
|
+ int num_small_nodes = 0; // counted nodes with ggml_nrows() <= 8
|
|
|
+ int num_counted_nodes = 0; // non-empty nodes other than GGML_OP_SET_ROWS
|
|
|
+ for (int i = 0; i < graph->n_nodes; ++i) {
|
|
|
+ if (!is_empty(graph->nodes[i]) &&
|
|
|
+ graph->nodes[i]->op != GGML_OP_SET_ROWS) {
|
|
|
+ if (ggml_nrows(graph->nodes[i]) <= 8) {
|
|
|
+ num_small_nodes++;
|
|
|
+ }
|
|
|
+ num_counted_nodes++;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (num_small_nodes < num_counted_nodes / 2) { // integer division; leave the graph untouched when small nodes are the minority
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ std::vector<ggml_tensor *> new_order; // nodes in their new order, filled one current_set at a time
|
|
|
+ std::vector<bool> used(graph->n_nodes, false); // used[i] is true once node i has been appended to new_order
|
|
|
+ int first_unused = 0; // lowest node index not yet appended to new_order
|
|
|
+ while (first_unused < graph->n_nodes) {
|
|
|
+ std::vector<int> current_set; // indices of the nodes grouped together in this iteration
|
|
|
+
|
|
|
+ // First, grab the next unused node. This guarantees progress on every iteration.
|
|
|
+ current_set.push_back(first_unused);
|
|
|
+
|
|
|
+ // Loop through the next N nodes. Grab any that don't depend on other nodes that
|
|
|
+ // haven't already been run. Nodes that have already been run have used[i] set
|
|
|
+ // to true. Allow nodes that depend on the previous node if it's a fusion pattern
|
|
|
+ // that we support (e.g. RMS_NORM + MUL).
|
|
|
+ // This first pass only grabs "real" (non-view) nodes. Second pass grabs view nodes.
|
|
|
+ // The goal is to not interleave real and view nodes in a way that breaks fusion.
|
|
|
+ const int NUM_TO_CHECK = 20; // lookahead window: candidates are indices first_unused+1 .. first_unused+NUM_TO_CHECK-1
|
|
|
+ for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
|
|
|
+ if (used[j]) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (is_empty(graph->nodes[j])) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ bool ok = true;
|
|
|
+ for (int c = first_unused; c < j; ++c) {
|
|
|
+ if (!used[c] &&
|
|
|
+ is_src_of(graph->nodes[j], graph->nodes[c]) &&
|
|
|
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) { // exception: a MUL at index c+1 when c is the RMS_NORM grabbed last
|
|
|
+ ok = false;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (ok) {
|
|
|
+ current_set.push_back(j);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // Second pass grabs view nodes (those for which is_empty() is true).
|
|
|
+ // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).
|
|
|
+ if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
|
|
|
+ for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
|
|
|
+ if (used[j]) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ if (!is_empty(graph->nodes[j])) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ bool ok = true;
|
|
|
+ for (int c = first_unused; c < j; ++c) {
|
|
|
+ bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); // members of current_set are not marked used until the set is pushed below
|
|
|
+ // Skip views whose srcs haven't been processed and aren't part of the current set.
|
|
|
+ if (!used[c] &&
|
|
|
+ is_src_of(graph->nodes[j], graph->nodes[c]) &&
|
|
|
+ !c_in_current_set) {
|
|
|
+ ok = false;
|
|
|
+ break;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (ok) {
|
|
|
+ current_set.push_back(j);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ // Push the current set into new_order and mark its nodes as used.
|
|
|
+ for (auto c : current_set) {
|
|
|
+ new_order.push_back(graph->nodes[c]);
|
|
|
+ used[c] = true;
|
|
|
+ }
|
|
|
+ while (first_unused < graph->n_nodes && used[first_unused]) { // advance past every node that has already been placed
|
|
|
+ first_unused++;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ // Replace the graph with the new order (each node index was pushed to new_order exactly once above).
|
|
|
+ for (int i = 0; i < graph->n_nodes; ++i) {
|
|
|
+ graph->nodes[i] = new_order[i];
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
// TODO: enable async and synchronize
|
|
|
static ggml_backend_i ggml_backend_vk_interface = {
|
|
|
/* .get_name = */ ggml_backend_vk_name,
|
|
|
@@ -11868,6 +11997,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
|
|
|
/* .graph_compute = */ ggml_backend_vk_graph_compute,
|
|
|
/* .event_record = */ NULL,
|
|
|
/* .event_wait = */ NULL,
|
|
|
+ /* .optimize_graph = */ ggml_vk_optimize_graph,
|
|
|
};
|
|
|
|
|
|
static ggml_guid_t ggml_backend_vk_guid() {
|