1
0
Эх сурвалжийг харах

vulkan: implement several ops relevant for ggml_opt (#11769)

* vulkan: support memset_tensor

* vulkan: support GGML_OP_SUM

* vulkan: implement GGML_OP_ARGMAX

* vulkan: implement GGML_OP_SUB

* vulkan: implement GGML_OP_COUNT_EQUAL

* vulkan: implement GGML_OP_OPT_STEP_ADAMW

* vulkan: fix check_results RWKV_WKV6 crash and memory leaks

* vulkan: implement GGML_OP_REPEAT_BACK

* tests: remove invalid test-backend-ops REPEAT_BACK tests

* vulkan: fix COUNT_EQUAL memset using a fillBuffer command
Rémy O 11 сар өмнө
parent
commit
2eea03d86a

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 359 - 212
ggml/src/ggml-vulkan/ggml-vulkan.cpp


+ 51 - 0
ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp

@@ -0,0 +1,51 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 32;
+
+shared FLOAT_TYPE tmpmax[BLOCK_SIZE];
+shared uint tmp[BLOCK_SIZE];
+
+void main() {
+    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
+    const uint col = gl_LocalInvocationID.x;
+
+    if (col >= p.KX) {
+        return;
+    }
+    A_TYPE amax = data_a[row*p.KX + col];
+    tmp[col] = col;
+
+    for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) {
+        A_TYPE val = data_a[row*p.KX + i];
+        if (val > amax) {
+            amax = val;
+            tmp[col] = i;
+        }
+    }
+    tmpmax[col] = amax;
+
+    barrier();
+    [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) {
+        if (col < s && col + s < p.KX) {
+            if (tmpmax[col] < tmpmax[col + s]) {
+                tmpmax[col] = tmpmax[col + s];
+                tmp[col] = tmp[col + s];
+            }
+        }
+        barrier();
+    }
+
+    if (col == 0) {
+        data_d[row] = D_TYPE(tmp[0]);
+    }
+}

+ 31 - 0
ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp

@@ -0,0 +1,31 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.comp"
+#include "generic_head.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+const uint CHUNK_SIZE = 512;
+
+void main() {
+    const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
+    const uint col = gl_LocalInvocationID.x;
+
+    uint count = 0;
+    [[unroll]]
+    for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
+        const uint idx = base + i + col;
+        if (idx >= p.KX) {
+            break;
+        }
+        count += uint(data_a[idx] == data_b[idx]);
+    }
+
+    atomicAdd(data_d[0], D_TYPE(count));
+}

+ 42 - 0
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp

@@ -0,0 +1,42 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE x[];};
+layout (binding = 1) readonly buffer G {A_TYPE grad[];};
+layout (binding = 2) buffer GM {A_TYPE gradm[];};
+layout (binding = 3) buffer GV {A_TYPE gradv[];};
+layout (binding = 4) readonly buffer P {float params[7];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float alpha  = params[0];
+    const float beta1  = params[1];
+    const float beta2  = params[2];
+    const float eps    = params[3];
+    const float wd     = params[4];
+    const float beta1h = params[5];
+    const float beta2h = params[6];
+
+    const float gi = grad[i];
+    const float gmi = gradm[i]*beta1 +    gi*(1.0f - beta1);
+    const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);
+
+    gradm[i] = gmi;
+    gradv[i] = gvi;
+
+    const float mh =      gmi*beta1h;
+    const float vh = sqrt(gvi*beta2h) + eps;
+
+    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+}

+ 37 - 0
ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp

@@ -0,0 +1,37 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    // Destination multi-index (inlined dst_idx)
+    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
+    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
+    const uint i12_offset = i12*p.ne11*p.ne10;
+    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
+    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+    const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10;
+
+    // Accumulate from sources
+    A_TYPE acc = A_TYPE(0);
+    for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) {
+        for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) {
+            for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) {
+                for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) {
+                    acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00];
+                }
+            }
+        }
+    }
+
+    data_d[get_doffset() + d_idx] = D_TYPE(acc);
+}

+ 29 - 0
ggml/src/ggml-vulkan/vulkan-shaders/sub.comp

@@ -0,0 +1,29 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    uint idx = get_idx();
+
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
+
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
+}

+ 7 - 0
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -443,6 +443,8 @@ void process_shaders() {
     string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
 
 
+    string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
     string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
 
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
@@ -452,6 +454,7 @@ void process_shaders() {
     string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
 
     string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
 
     string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
 
@@ -501,7 +504,9 @@ void process_shaders() {
 
 
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
 
 
+    string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
 
 
     string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
     string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
@@ -513,6 +518,8 @@ void process_shaders() {
 
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
 
+    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
     for (auto &c : compiles) {
     for (auto &c : compiles) {
         c.wait();
         c.wait();
     }
     }

+ 5 - 5
tests/test-backend-ops.cpp

@@ -1254,7 +1254,7 @@ struct test_count_equal : public test_case {
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(b, "b");
         ggml_set_name(b, "b");
 
 
-        ggml_tensor * b_argmax = ggml_argmax(ctx, a);
+        ggml_tensor * b_argmax = ggml_argmax(ctx, b);
         ggml_set_name(b_argmax, "b_argmax");
         ggml_set_name(b_argmax, "b_argmax");
 
 
         ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
         ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
@@ -1511,6 +1511,7 @@ struct test_cont : public test_case {
 };
 };
 
 
 // GGML_OP_ADD
 // GGML_OP_ADD
+// GGML_OP_SUB
 // GGML_OP_MUL
 // GGML_OP_MUL
 // GGML_OP_DIV
 // GGML_OP_DIV
 struct test_bin_bcast : public test_case {
 struct test_bin_bcast : public test_case {
@@ -3860,7 +3861,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
 
-    test_cases.emplace_back(new test_count_equal());
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4,  500, 1, 1}));
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));
 
 
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32,    1, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32,    1, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100,  10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100,  10, 1, 1}));
@@ -3885,8 +3887,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
         test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
-        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
-        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
     }
     }
 
 
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
@@ -3938,7 +3938,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
     test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));
 
 
     auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
     auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
-        for (auto op : {ggml_add, ggml_mul, ggml_div}) {
+        for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
             test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
             test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
         }
         }
     };
     };

Энэ ялгаанд хэт олон файл өөрчлөгдсөн тул зарим файлыг харуулаагүй болно