
vulkan: implement ABS and NEG (#17245)

* docs: update Vulkan ops

* vulkan: add NEG op

* vulkan: add ABS op

---------

Signed-off-by: Giuseppe Scrivano <gscrivan@redhat.com>
Giuseppe Scrivano 2 months ago
parent
commit
1568d13c2c

+ 17 - 17
docs/ops.md

@@ -14,14 +14,14 @@ Legend:
 
 | Operation | BLAS | CANN | CPU | CUDA | Metal | OpenCL | SYCL | Vulkan | zDNN |
 |-----------|------|------|------|------|------|------|------|------|------|
-|                              ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 |  | ❌ |
+|                              ABS | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
 |                              ACC | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                              ADD | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 |                             ADD1 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
-|                           ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |  | ❌ |
+|                           ADD_ID | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |  | ❌ |
 |                           ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
 |                           ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
-|                          ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |  | ❌ |
+|                          ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ❌ |
 |                             CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
 |                            CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
 |                           CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ |
@@ -30,7 +30,7 @@ Legend:
 |                       CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
 |                          CONV_3D | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
-|                CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |  | ❌ |
+|                CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |  | ❌ |
 |                              COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
 |                      COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
 |                              CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ |
@@ -41,7 +41,7 @@ Legend:
 |                              DIV | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 |                              DUP | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
 |                              ELU | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
-|                              EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 |  | ❌ |
+|                              EXP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
 |                            EXPM1 | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                             FILL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                   FLASH_ATTN_EXT | ❌ | 🟡 | ✅ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ |
@@ -57,22 +57,22 @@ Legend:
 |                    GET_ROWS_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                       GROUP_NORM | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 |               GROUP_NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
-|                      HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 |  | ❌ |
-|                        HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 |  | ❌ |
+|                      HARDSIGMOID | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
+|                        HARDSWISH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
 |                           IM2COL | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
-|                        IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |  | ❌ |
+|                        IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |  | ❌ |
 |                          L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
-|                       LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |  | ❌ |
+|                       LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
 |                              LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
-|                             MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |  | ❌ |
+|                             MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ |  | ❌ |
 |                              MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 |                          MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
 |                       MUL_MAT_ID | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ |
-|                              NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 |  | ❌ |
+|                              NEG | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | 🟡 | ❌ |
 |                             NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
 |                     NORM_MUL_ADD | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
 |                   OPT_STEP_ADAMW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
-|                     OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |  | ❌ |
+|                     OPT_STEP_SGD | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |  | ❌ |
 |                         OUT_PROD | 🟡 | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ❌ | ❌ |
 |                              PAD | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
 |                   PAD_REFLECT_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
@@ -83,7 +83,7 @@ Legend:
 |                      REPEAT_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
 |                         RMS_NORM | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ |
 |                    RMS_NORM_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
-|                 RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ |  | ❌ |
+|                 RMS_NORM_MUL_ADD | ❌ | ✅ | ❌ | ❌ | ✅ | ✅ | ✅ |  | ❌ |
 |                             ROLL | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
 |                             ROPE | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 |                        ROPE_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
@@ -104,15 +104,15 @@ Legend:
 |                    SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ |
 |                        SOLVE_TRI | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
 |                              SQR | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
-|                             SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ |  | ❌ |
+|                             SQRT | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ |
 |                         SSM_CONV | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
-|                         SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |  | ❌ |
+|                         SSM_SCAN | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 | ❌ |
 |                             STEP | ❌ | ✅ | ✅ | 🟡 | 🟡 | ❌ | 🟡 | ❌ | ❌ |
 |                              SUB | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
-|                              SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 |  | ❌ |
+|                              SUM | ❌ | ✅ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ |
 |                         SUM_ROWS | ❌ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ |
 |                           SWIGLU | ❌ | ✅ | ✅ | ✅ | 🟡 | ✅ | ✅ | 🟡 | ❌ |
-|                       SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |  | ❌ |
+|                       SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | 🟡 | ❌ |
 |                             TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | 🟡 | ❌ |
 |               TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
 |                         TOPK_MOE | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |

File diff suppressed because it is too large
+ 2535 - 402
docs/ops/Vulkan.csv


+ 22 - 0
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -656,10 +656,12 @@ struct vk_device_struct {
     vk_pipeline pipeline_gelu_quick[2];
     vk_pipeline pipeline_silu[2];
     vk_pipeline pipeline_relu[2];
+    vk_pipeline pipeline_neg[2];
     vk_pipeline pipeline_tanh[2];
     vk_pipeline pipeline_sigmoid[2];
     vk_pipeline pipeline_hardsigmoid[2];
     vk_pipeline pipeline_hardswish[2];
+    vk_pipeline pipeline_abs[2];
 
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
@@ -3804,10 +3806,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_UNARY(gelu_quick)
     CREATE_UNARY(silu)
     CREATE_UNARY(relu)
+    CREATE_UNARY(neg)
     CREATE_UNARY(tanh)
     CREATE_UNARY(sigmoid)
     CREATE_UNARY(hardsigmoid)
     CREATE_UNARY(hardswish)
+    CREATE_UNARY(abs)
 #undef CREATE_UNARY
 
 #define CREATE_UNARY_RTE(name)  \
@@ -8170,6 +8174,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_RELU:
                 return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
+            case GGML_UNARY_OP_NEG:
+                return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_TANH:
                 return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_SIGMOID:
@@ -8178,6 +8184,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_HARDSWISH:
                 return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
+            case GGML_UNARY_OP_ABS:
+                return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
             default:
                 break;
         }
@@ -11106,10 +11114,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_NEG:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_SIGMOID:
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
+        case GGML_UNARY_OP_ABS:
             break;
         default:
             return false;
@@ -11436,10 +11446,12 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_NEG:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_SIGMOID:
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
+        case GGML_UNARY_OP_ABS:
             ggml_vk_unary(ctx, compute_ctx, src0, node);
             break;
         default:
@@ -11706,10 +11718,12 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_RELU:
+        case GGML_UNARY_OP_NEG:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_SIGMOID:
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
+        case GGML_UNARY_OP_ABS:
             buf = tensor->buffer;
             break;
         default:
@@ -13235,10 +13249,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_NEG:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_ABS:
                     return ggml_is_contiguous(op->src[0]) &&
                            (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                            (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14116,6 +14132,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             case GGML_UNARY_OP_RELU:
                 tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
                 break;
+            case GGML_UNARY_OP_NEG:
+                tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
+                break;
             case GGML_UNARY_OP_TANH:
                 tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
                 break;
@@ -14128,6 +14147,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             case GGML_UNARY_OP_HARDSWISH:
                 tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
                 break;
+            case GGML_UNARY_OP_ABS:
+                tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
+                break;
             default:
                 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
                 GGML_ABORT("fatal error");
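
The host-side changes above register the new pipelines, route GGML_UNARY_OP_NEG and GGML_UNARY_OP_ABS through the generic unary path, and extend supports_op so the ops are only claimed for contiguous F32/F16 tensors with matching source and destination types. Below is a minimal, hypothetical sketch (not part of this commit) of exercising the two ops end to end through the public ggml-backend API; the device index, buffer sizing, and tensor shape are illustrative assumptions.

// Hypothetical usage sketch, not part of this commit: build a tiny graph with the
// newly supported NEG and ABS unary ops and run it on the Vulkan backend.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-vulkan.h"
#include <cstdio>
#include <vector>

int main() {
    // Context holds only tensor/graph metadata; data lives in the backend buffer.
    ggml_init_params params = { /*mem_size*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
                                /*mem_buffer*/ nullptr, /*no_alloc*/ true };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8); // contiguous F32, as required by supports_op
    ggml_tensor * y = ggml_abs(ctx, ggml_neg(ctx, x));           // ABS(NEG(x)) == |x|

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);

    ggml_backend_t vk = ggml_backend_vk_init(0);                 // device index 0 is illustrative
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, vk);

    const std::vector<float> src = { -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, -4.f };
    ggml_backend_tensor_set(x, src.data(), 0, src.size() * sizeof(float));

    ggml_backend_graph_compute(vk, gf);                          // dispatches neg.comp, then abs.comp

    std::vector<float> out(8);
    ggml_backend_tensor_get(y, out.data(), 0, out.size() * sizeof(float));
    for (float v : out) printf("%g ", v);                        // expect 3 2 1 0 1 2 3 4
    printf("\n");

    ggml_backend_buffer_free(buf);
    ggml_backend_free(vk);
    ggml_free(ctx);
}

Because supports_op only claims contiguous F32/F16 tensors, any other layout or type is left to whichever backend the scheduler falls back to.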

+ 21 - 0
ggml/src/ggml-vulkan/vulkan-shaders/abs.comp

@@ -0,0 +1,21 @@
+#version 450
+
+#include "generic_head.glsl"
+#include "types.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    data_d[i] = D_TYPE(abs(float(data_a[i])));
+}

+ 20 - 0
ggml/src/ggml-vulkan/vulkan-shaders/neg.comp

@@ -0,0 +1,20 @@
+#version 450
+
+#include "generic_head.glsl"
+#include "types.glsl"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+    data_d[i] = D_TYPE(-float(data_a[i]));
+}
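
Both new shaders follow the flat-index convention of the existing unary kernels: with local_size_x = 512, the host spreads ggml_nelements over an up-to 512 x 512 x N invocation grid, so the 262144 constant is simply 512 * 512 and surplus invocations exit on the i >= p.KX check. A rough sketch of that grid selection is below; the helper name is hypothetical and the real logic is the generic element-wise dispatch in ggml-vulkan.cpp.

#include <array>
#include <cstdint>

// Hypothetical helper mirroring how the generic element-wise dispatch in
// ggml-vulkan.cpp sizes the grid for a 512-wide workgroup (not verbatim code).
// 262144 = 512 * 512; the shader masks invocations with i >= p.KX.
static std::array<uint32_t, 3> unary_dispatch_grid(uint32_t ne) {
    if (ne > 262144u) {
        return { 512u, 512u, (ne + 262143u) / 262144u };
    }
    if (ne > 512u) {
        return { 512u, (ne + 511u) / 512u, 1u };
    }
    return { ne, 1u, 1u };
}
// The shader then reconstructs the element index as
//   i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x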

+ 4 - 0
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -827,6 +827,8 @@ void process_shaders() {
     string_to_spv("silu_f32",       "silu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("relu_f16",       "relu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("relu_f32",       "relu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("neg_f16",        "neg.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("neg_f32",        "neg.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("tanh_f16",       "tanh.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("tanh_f32",       "tanh.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("sigmoid_f16",    "sigmoid.comp",     {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
@@ -835,6 +837,8 @@ void process_shaders() {
     string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
     string_to_spv("hardswish_f16",  "hardswish.comp",   {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("hardswish_f32",  "hardswish.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("abs_f16",        "abs.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("abs_f32",        "abs.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 
     for (auto rte : {false, true}) {
         std::string suffix = rte ? "_rte" : "";

Some files were not shown because too many files changed in this diff