Pārlūkot izejas kodu

vulkan : move contiguous checks to device_supports_op (#17490)

* vulkan : remove op_supports_incontiguous and add missing constraints in device_supports_op

* im2col: remove contraints on src0 (kernel input)
Acly 1 mēnesi atpakaļ
vecāks
revīzija
b78db3bd50
1 mainītis faili ar 35 papildinājumiem un 50 dzēšanām
  1. 35 50
      ggml/src/ggml-vulkan/ggml-vulkan.cpp

+ 35 - 50
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -8687,41 +8687,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     GGML_UNUSED(src2);
 }
 
-static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
-    switch (op) {
-    case GGML_OP_CPY:
-    case GGML_OP_GET_ROWS:
-    case GGML_OP_ADD:
-    case GGML_OP_SUB:
-    case GGML_OP_MUL:
-    case GGML_OP_DIV:
-    case GGML_OP_ADD_ID:
-    case GGML_OP_CONCAT:
-    case GGML_OP_UPSCALE:
-    case GGML_OP_SQR:
-    case GGML_OP_SQRT:
-    case GGML_OP_SIN:
-    case GGML_OP_COS:
-    case GGML_OP_LOG:
-    case GGML_OP_CLAMP:
-    case GGML_OP_PAD:
-    case GGML_OP_REPEAT:
-    case GGML_OP_REPEAT_BACK:
-    case GGML_OP_ROPE:
-    case GGML_OP_RMS_NORM:
-    case GGML_OP_CONV_2D_DW:
-    case GGML_OP_IM2COL:
-    case GGML_OP_IM2COL_3D:
-    case GGML_OP_SET_ROWS:
-    case GGML_OP_SUM:
-    case GGML_OP_SUM_ROWS:
-    case GGML_OP_MEAN:
-        return true;
-    default:
-        return false;
-    }
-}
-
 template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
     const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
     const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@@ -8806,7 +8771,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
     std::cerr << "), " << ggml_op_name(op) << ")");
     GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))));  // NOLINT
-    GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0));  // NOLINT
     GGML_ASSERT(dst->buffer != nullptr);
     const uint64_t ne00 = src0->ne[0];
     const uint64_t ne01 = src0->ne[1];
@@ -8837,22 +8801,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 
-    const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
-
-    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
-    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
-    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
-    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
-    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
+    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);
+    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};
+    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};
+    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
 
     // Compute misalignment offset for descriptors and store it in in push constants.
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
 
     std::array<uint32_t, 3> elements;
 
-    // Single call if dimension 2 is contiguous
-    GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
-
     switch (op) {
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM_BACK:
@@ -13876,15 +13835,17 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                    op->type == GGML_TYPE_F32;
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
-            return op->src[0]->type == GGML_TYPE_F32;
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_LOG:
             return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_ARGSORT:
@@ -13919,17 +13880,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return true;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_CONCAT:
+            return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
         case GGML_OP_ADD1:
+            return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)
+                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
+                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
         case GGML_OP_ARANGE:
         case GGML_OP_FILL:
+            return op->type == GGML_TYPE_F32;
         case GGML_OP_SCALE:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:
         case GGML_OP_ROLL:
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_DIAG_MASK_INF:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SOFT_MAX:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
+                && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));
         case GGML_OP_SOFT_MAX_BACK:
-            return true;
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
+                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
@@ -13944,15 +13917,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return false;
             }
         case GGML_OP_ARGMAX:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_COUNT_EQUAL:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32
+                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;
         case GGML_OP_IM2COL:
+            return ggml_is_contiguous(op->src[1])
+                && op->src[1]->type == GGML_TYPE_F32
+                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
         case GGML_OP_IM2COL_3D:
+            return op->src[1]->type == GGML_TYPE_F32
+                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
         case GGML_OP_TIMESTEP_EMBEDDING:
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_CONV_2D_DW:
+            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)
+                && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_POOL_2D:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
-            return true;
+            return true; // all inputs are contiguous, see ggml.c
         case GGML_OP_SSM_SCAN:
             {
                 for (int i = 0; i < 6; i++) {
@@ -13993,7 +13978,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 return true;
             }
         case GGML_OP_SSM_CONV:
-            return true;
+            return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_CONV_TRANSPOSE_1D:
             return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_CONV_2D: