Просмотр исходного кода

vulkan: remove a couple unnecessary switches (#17419)

Jeff Bolz 1 месяц назад
Родитель
Сommit
54d83bbe85
1 измененных файлов с 17 добавлено и 263 удалено
  1. 17 263
      ggml/src/ggml-vulkan/ggml-vulkan.cpp

+ 17 - 263
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -11381,13 +11381,13 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
+static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
 
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
 static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
     ggml_tensor * node = cgraph->nodes[node_idx];
-    if (ggml_is_empty(node) || !node->buffer) {
+    if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
         return false;
     }
 
@@ -11399,132 +11399,19 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     ggml_tensor * src2 = node->src[2];
     ggml_tensor * src3 = node->src[3];
 
-    switch (node->op) {
-    // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
-    case GGML_OP_RESHAPE:
-    case GGML_OP_VIEW:
-    case GGML_OP_PERMUTE:
-    case GGML_OP_TRANSPOSE:
-    case GGML_OP_NONE:
-        return false;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(node)) {
-        case GGML_UNARY_OP_EXP:
-        case GGML_UNARY_OP_SILU:
-        case GGML_UNARY_OP_GELU:
-        case GGML_UNARY_OP_GELU_ERF:
-        case GGML_UNARY_OP_GELU_QUICK:
-        case GGML_UNARY_OP_RELU:
-        case GGML_UNARY_OP_NEG:
-        case GGML_UNARY_OP_TANH:
-        case GGML_UNARY_OP_SIGMOID:
-        case GGML_UNARY_OP_HARDSIGMOID:
-        case GGML_UNARY_OP_HARDSWISH:
-        case GGML_UNARY_OP_ABS:
-        case GGML_UNARY_OP_SOFTPLUS:
-        case GGML_UNARY_OP_STEP:
-        case GGML_UNARY_OP_ROUND:
-        case GGML_UNARY_OP_CEIL:
-        case GGML_UNARY_OP_FLOOR:
-        case GGML_UNARY_OP_TRUNC:
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_GLU:
-        switch (ggml_get_glu_op(node)) {
-        case GGML_GLU_OP_GEGLU:
-        case GGML_GLU_OP_REGLU:
-        case GGML_GLU_OP_SWIGLU:
-        case GGML_GLU_OP_SWIGLU_OAI:
-        case GGML_GLU_OP_GEGLU_ERF:
-        case GGML_GLU_OP_GEGLU_QUICK:
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_ADD:
-        {
-            int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
-            if (next_node_idx < cgraph->n_nodes &&
-                cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
-                cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
-                ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
-                ctx->device->add_rms_fusion) {
-                uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
-                ctx->do_add_rms_partials_offset_calculation = true;
-                if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
-                    ctx->do_add_rms_partials = true;
-                }
+    if (node->op == GGML_OP_ADD) {
+        int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
+        if (next_node_idx < cgraph->n_nodes &&
+            cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
+            cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
+            ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
+            ctx->device->add_rms_fusion) {
+            uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
+            ctx->do_add_rms_partials_offset_calculation = true;
+            if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
+                ctx->do_add_rms_partials = true;
             }
-        } break;
-    case GGML_OP_REPEAT:
-    case GGML_OP_REPEAT_BACK:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD_ID:
-    case GGML_OP_ACC:
-    case GGML_OP_SUB:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_ADD1:
-    case GGML_OP_ARANGE:
-    case GGML_OP_FILL:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_SCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SQRT:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_LOG:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_ROLL:
-    case GGML_OP_CPY:
-    case GGML_OP_SET_ROWS:
-    case GGML_OP_CONT:
-    case GGML_OP_DUP:
-    case GGML_OP_SILU_BACK:
-    case GGML_OP_NORM:
-    case GGML_OP_GROUP_NORM:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_RMS_NORM_BACK:
-    case GGML_OP_L2_NORM:
-    case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_SOFT_MAX:
-    case GGML_OP_SOFT_MAX_BACK:
-    case GGML_OP_ROPE:
-    case GGML_OP_ROPE_BACK:
-    case GGML_OP_MUL_MAT:
-    case GGML_OP_MUL_MAT_ID:
-    case GGML_OP_ARGSORT:
-    case GGML_OP_SUM:
-    case GGML_OP_SUM_ROWS:
-    case GGML_OP_MEAN:
-    case GGML_OP_ARGMAX:
-    case GGML_OP_COUNT_EQUAL:
-    case GGML_OP_IM2COL:
-    case GGML_OP_IM2COL_3D:
-    case GGML_OP_TIMESTEP_EMBEDDING:
-    case GGML_OP_CONV_TRANSPOSE_1D:
-    case GGML_OP_POOL_2D:
-    case GGML_OP_CONV_2D:
-    case GGML_OP_CONV_TRANSPOSE_2D:
-    case GGML_OP_CONV_2D_DW:
-    case GGML_OP_RWKV_WKV6:
-    case GGML_OP_RWKV_WKV7:
-    case GGML_OP_SSM_SCAN:
-    case GGML_OP_SSM_CONV:
-    case GGML_OP_LEAKY_RELU:
-    case GGML_OP_FLASH_ATTN_EXT:
-    case GGML_OP_OPT_STEP_ADAMW:
-    case GGML_OP_OPT_STEP_SGD:
-        break;
-    default:
-        std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
-        GGML_ABORT("fatal error");
+        }
     }
 
     vk_context compute_ctx;
@@ -11961,145 +11848,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         ctx->compute_ctx.reset();
 
-        bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
-        if (!ok) {
-            if (node->op == GGML_OP_UNARY) {
-                std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
-            } else if (node->op == GGML_OP_GLU) {
-                std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
-            } else {
-                std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
-            }
-        }
-
+        ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
     }
     return true;
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
+static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
     GGML_UNUSED(cgraph);
-    ggml_backend_buffer * buf = nullptr;
-
-    switch (tensor->op) {
-    case GGML_OP_ADD:
-    case GGML_OP_ACC:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_SUB:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_ADD1:
-    case GGML_OP_ARANGE:
-    case GGML_OP_FILL:
-    case GGML_OP_ADD_ID:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_SCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SQRT:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_LOG:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_ROLL:
-    case GGML_OP_CPY:
-    case GGML_OP_SET_ROWS:
-    case GGML_OP_CONT:
-    case GGML_OP_DUP:
-    case GGML_OP_SILU_BACK:
-    case GGML_OP_NORM:
-    case GGML_OP_GROUP_NORM:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_RMS_NORM_BACK:
-    case GGML_OP_L2_NORM:
-    case GGML_OP_DIAG_MASK_INF:
-    case GGML_OP_SOFT_MAX:
-    case GGML_OP_SOFT_MAX_BACK:
-    case GGML_OP_ROPE:
-    case GGML_OP_ROPE_BACK:
-    case GGML_OP_RESHAPE:
-    case GGML_OP_VIEW:
-    case GGML_OP_PERMUTE:
-    case GGML_OP_TRANSPOSE:
-    case GGML_OP_NONE:
-    case GGML_OP_ARGSORT:
-    case GGML_OP_SUM:
-    case GGML_OP_SUM_ROWS:
-    case GGML_OP_MEAN:
-    case GGML_OP_ARGMAX:
-    case GGML_OP_COUNT_EQUAL:
-    case GGML_OP_IM2COL:
-    case GGML_OP_IM2COL_3D:
-    case GGML_OP_TIMESTEP_EMBEDDING:
-    case GGML_OP_CONV_TRANSPOSE_1D:
-    case GGML_OP_POOL_2D:
-    case GGML_OP_CONV_2D:
-    case GGML_OP_CONV_TRANSPOSE_2D:
-    case GGML_OP_CONV_2D_DW:
-    case GGML_OP_RWKV_WKV6:
-    case GGML_OP_RWKV_WKV7:
-    case GGML_OP_SSM_SCAN:
-    case GGML_OP_SSM_CONV:
-    case GGML_OP_LEAKY_RELU:
-    case GGML_OP_REPEAT:
-    case GGML_OP_REPEAT_BACK:
-    case GGML_OP_OPT_STEP_ADAMW:
-    case GGML_OP_OPT_STEP_SGD:
-        buf = tensor->buffer;
-        break;
-    case GGML_OP_UNARY:
-        switch (ggml_get_unary_op(tensor)) {
-        case GGML_UNARY_OP_EXP:
-        case GGML_UNARY_OP_SILU:
-        case GGML_UNARY_OP_GELU:
-        case GGML_UNARY_OP_GELU_ERF:
-        case GGML_UNARY_OP_GELU_QUICK:
-        case GGML_UNARY_OP_RELU:
-        case GGML_UNARY_OP_NEG:
-        case GGML_UNARY_OP_TANH:
-        case GGML_UNARY_OP_SIGMOID:
-        case GGML_UNARY_OP_HARDSIGMOID:
-        case GGML_UNARY_OP_HARDSWISH:
-        case GGML_UNARY_OP_ABS:
-        case GGML_UNARY_OP_SOFTPLUS:
-        case GGML_UNARY_OP_STEP:
-        case GGML_UNARY_OP_ROUND:
-        case GGML_UNARY_OP_CEIL:
-        case GGML_UNARY_OP_FLOOR:
-        case GGML_UNARY_OP_TRUNC:
-            buf = tensor->buffer;
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_GLU:
-        switch (ggml_get_glu_op(tensor)) {
-        case GGML_GLU_OP_GEGLU:
-        case GGML_GLU_OP_REGLU:
-        case GGML_GLU_OP_SWIGLU:
-        case GGML_GLU_OP_SWIGLU_OAI:
-        case GGML_GLU_OP_GEGLU_ERF:
-        case GGML_GLU_OP_GEGLU_QUICK:
-            buf = tensor->buffer;
-            break;
-        default:
-            return false;
-        }
-        break;
-    case GGML_OP_MUL_MAT:
-    case GGML_OP_MUL_MAT_ID:
-    case GGML_OP_FLASH_ATTN_EXT:
-        buf = tensor->buffer;
-
-        break;
-    default:
-        return false;
-    }
-
-    if (buf == nullptr) {
-        return false;
-    }
+    GGML_UNUSED(tensor);
 
     VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
 
@@ -12143,8 +11899,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         subctx->out_memcpys.clear();
         subctx->memsets.clear();
     }
-
-    return true;
 }
 
 // Clean up after graph processing is done