소스 검색

vulkan: support GGML_OP_DIAG (#17893)

Jeff Bolz 1 개월 전
부모
커밋
3229a23fa6
3개의 변경된 파일55개의 추가작업 그리고 0개의 파일을 삭제
  1. 24 0
      ggml/src/ggml-vulkan/ggml-vulkan.cpp
  2. 29 0
      ggml/src/ggml-vulkan/vulkan-shaders/diag.comp
  3. 2 0
      ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

+ 24 - 0
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -659,6 +659,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_log[2];
     vk_pipeline pipeline_tri[2];
+    vk_pipeline pipeline_diag[2];
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_roll_f32;
@@ -3924,6 +3925,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -8416,6 +8420,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
         }
         return nullptr;
+    case GGML_OP_DIAG:
+        if (src0->type == dst->type &&
+            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
+            return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -9109,6 +9119,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_COS:
     case GGML_OP_LOG:
     case GGML_OP_TRI:
+    case GGML_OP_DIAG:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -9796,6 +9807,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
 }
 
+static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
+
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
     p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11924,6 +11941,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_TRI:
         ggml_vk_tri(ctx, compute_ctx, src0, node);
 
+        break;
+    case GGML_OP_DIAG:
+        ggml_vk_diag(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_CLAMP:
         ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -14067,6 +14088,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_LOG:
         case GGML_OP_TRI:
+        case GGML_OP_DIAG:
             return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    op->type == op->src[0]->type;
         case GGML_OP_ARGSORT:
@@ -14657,6 +14679,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_TRI) {
             tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
+        } else if (tensor->op == GGML_OP_DIAG) {
+            tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_CLAMP) {
             const float * params = (const float *)tensor->op_params;
             tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);

+ 29 - 0
ggml/src/ggml-vulkan/vulkan-shaders/diag.comp

@@ -0,0 +1,29 @@
+#version 450
+
+#include "rte.glsl"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
+    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
+    const uint i12_offset = i12*p.ne11*p.ne10;
+    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
+    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+
+    if (i10 == i11) {
+        const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]);
+        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
+    } else {
+        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
+    }
+}

+ 2 - 0
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -854,6 +854,8 @@ void process_shaders() {
 
     string_to_spv("tri_f16",        "tri.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("tri_f32",        "tri.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("diag_f16",       "diag.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("diag_f32",       "diag.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 
     string_to_spv("softplus_f16",   "softplus.comp",    {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("softplus_f32",   "softplus.comp",    {{"A_TYPE", "float"},       {"D_TYPE", "float"}});