Quellcode durchsuchen

vulkan: add LOG operation support for F32 and F16 (#17183)

* vulkan: add LOG operation support for F32 and F16

Part of #14909.

* vulkan: Fix LOG operation types

* docs: Update operation support documentation for Vulkan LOG operation

* vulkan: fix log_f16 shader

* docs: restore missing LOG test cases and regenerate ops.md
Pavels Zaicenkovs vor 2 Monaten
Ursprung
Commit
dbed61294a

+ 1 - 1
docs/ops.md

@@ -63,7 +63,7 @@ Legend:
 |                        IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ |
 |                          L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                       LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | 🟡 | ❌ |
-|                              LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 |  | ❌ |
+|                              LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | 🟡 |  | ❌ |
 |                             MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
 |                              MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
 |                          MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |

+ 4 - 4
docs/ops/Vulkan.csv

@@ -8627,7 +8627,7 @@
 "Vulkan0","ADD_ID","type_a=f32,type_b=f32,n_embd=129,n_experts=8,n_experts_used=4,n_token=129","support","1","yes","Vulkan"
 "Vulkan0","SQR","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan"
 "Vulkan0","SQRT","type=f16,ne=[10,3,3,2]","support","0","no","Vulkan"
-"Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan"
+"Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","1","yes","Vulkan"
 "Vulkan0","SIN","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
 "Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
 "Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
@@ -8638,7 +8638,7 @@
 "Vulkan0","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
 "Vulkan0","SQR","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
 "Vulkan0","SQRT","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
-"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
+"Vulkan0","LOG","type=f16,ne=[7,1,5,3]","support","1","yes","Vulkan"
 "Vulkan0","SIN","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
 "Vulkan0","COS","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
 "Vulkan0","CLAMP","type=f16,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
@@ -8649,7 +8649,7 @@
 "Vulkan0","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","Vulkan"
 "Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
 "Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","1","yes","Vulkan"
-"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
+"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
 "Vulkan0","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
 "Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
 "Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
@@ -8660,7 +8660,7 @@
 "Vulkan0","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","Vulkan"
 "Vulkan0","SQR","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
 "Vulkan0","SQRT","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
-"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","0","no","Vulkan"
+"Vulkan0","LOG","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
 "Vulkan0","SIN","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
 "Vulkan0","COS","type=f32,ne=[7,1,5,3]","support","1","yes","Vulkan"
 "Vulkan0","CLAMP","type=f32,ne=[7,1,5,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"

+ 25 - 0
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -629,6 +629,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_sqrt_f32;
     vk_pipeline pipeline_sin_f32;
     vk_pipeline pipeline_cos_f32;
+    vk_pipeline pipeline_log[2];
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_roll_f32;
@@ -3792,6 +3793,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
@@ -8126,6 +8129,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_cos_f32;
         }
         return nullptr;
+    case GGML_OP_LOG:
+        if (src0->type == dst->type &&
+            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
+            return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -8534,6 +8543,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
+    case GGML_OP_LOG:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_REPEAT:
@@ -8806,6 +8816,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
+    case GGML_OP_LOG:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -9414,6 +9425,10 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst));
 }
 
+static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
     p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11209,6 +11224,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
+    case GGML_OP_LOG:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -11433,6 +11449,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_COS:
         ggml_vk_cos(ctx, compute_ctx, src0, node);
 
+        break;
+    case GGML_OP_LOG:
+        ggml_vk_log(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_CLAMP:
         ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -11703,6 +11723,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_SQRT:
     case GGML_OP_SIN:
     case GGML_OP_COS:
+    case GGML_OP_LOG:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -13664,6 +13685,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_LOG:
+            return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_ARGSORT:
             return op->ne[0] <= max_argsort_cols;
         case GGML_OP_UPSCALE:
@@ -14159,6 +14182,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_COS) {
             tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
+        } else if (tensor->op == GGML_OP_LOG) {
+            tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_CLAMP) {
             const float * params = (const float *)tensor->op_params;
             tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);

+ 17 - 0
ggml/src/ggml-vulkan/vulkan-shaders/log.comp

@@ -0,0 +1,17 @@
+#version 450
+
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
+    data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));
+}

+ 3 - 0
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -802,6 +802,9 @@ void process_shaders() {
 
     string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
+    string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+    string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
+
     string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
     string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});