Selaa lähdekoodia

vulkan: Implement GGML_OP_TRI (#17503)

* vulkan: Implement GGML_OP_TRI

* check types match
Jeff Bolz 1 kuukausi sitten
vanhempi
sitoutus
35cf8887e1

+ 27 - 1
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -649,6 +649,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_sin_f32;
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_log[2];
+    vk_pipeline pipeline_tri[2];
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_roll_f32;
@@ -3876,6 +3877,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     }
 
+    ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -8290,6 +8294,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
         }
         return nullptr;
+    case GGML_OP_TRI:
+        if (src0->type == dst->type &&
+            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
+            return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -8991,6 +9001,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SIN:
     case GGML_OP_COS:
     case GGML_OP_LOG:
+    case GGML_OP_TRI:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -9671,6 +9682,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
 }
 
+static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
+    p.param1 = ggml_get_op_params_f32(dst, 0);
+
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
     p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11794,6 +11812,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_LOG:
         ggml_vk_log(ctx, compute_ctx, src0, node);
 
+        break;
+    case GGML_OP_TRI:
+        ggml_vk_tri(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_CLAMP:
         ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -13919,7 +13941,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_OPT_STEP_SGD:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_LOG:
-            return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
+        case GGML_OP_TRI:
+            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                   op->type == op->src[0]->type;
         case GGML_OP_ARGSORT:
             {
                 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
@@ -14510,6 +14534,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_LOG) {
             tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
+        } else if (tensor->op == GGML_OP_TRI) {
+            tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
         } else if (tensor->op == GGML_OP_CLAMP) {
             const float * params = (const float *)tensor->op_params;
             tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);

+ 43 - 0
ggml/src/ggml-vulkan/vulkan-shaders/tri.comp

@@ -0,0 +1,43 @@
+#version 450
+
+#include "rte.glsl"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+#define GGML_TRI_TYPE_UPPER_DIAG 0
+#define GGML_TRI_TYPE_UPPER      1
+#define GGML_TRI_TYPE_LOWER_DIAG 2
+#define GGML_TRI_TYPE_LOWER      3
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
+    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
+    const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
+    const uint i02_offset = i02*p.ne01*p.ne00;
+    const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
+    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+
+    int param = floatBitsToInt(p.param1);
+    bool pass = false;
+    switch (param) {
+    case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
+    case GGML_TRI_TYPE_UPPER:      pass = i00 >  i01; break;
+    case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
+    case GGML_TRI_TYPE_LOWER:      pass = i00 <  i01; break;
+    }
+
+    if (pass) {
+        const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
+        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
+    } else {
+        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
+    }
+}

+ 3 - 0
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -846,6 +846,9 @@ void process_shaders() {
     string_to_spv("abs_f16",        "abs.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("abs_f32",        "abs.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 
+    string_to_spv("tri_f16",        "tri.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("tri_f32",        "tri.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+
     string_to_spv("softplus_f16",   "softplus.comp",    {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("softplus_f32",   "softplus.comp",    {{"A_TYPE", "float"},       {"D_TYPE", "float"}});