Browse source code

vulkan: Optimize binary ops (#10270)

Reuse the index calculations across all of src0/src1/dst. Add a shader
variant for when src0/src1 have the same dimensions and the additional modulus
operations for src1 aren't needed. Div/mod are slow, so add "fast" div/mod that
have a fast path when the calculation isn't needed or can be done more
cheaply.
Jeff Bolz 1 year ago
parent
commit
af148c9386

+ 16 - 11
ggml/src/ggml-vulkan.cpp

@@ -192,9 +192,10 @@ struct vk_device_struct {
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_acc_f32;
-    vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
-    vk_pipeline pipeline_mul_f32;
-    vk_pipeline pipeline_div_f32;
+    vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
+    vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+    vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
+    vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_scale_f32;
@@ -1456,13 +1457,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
 
-    ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
 
 
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
 
-    ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
 
 
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3801,20 +3806,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
         return nullptr;
     case GGML_OP_ADD:
     case GGML_OP_ADD:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_add_f32;
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
         }
         }
         if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
         if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-            return ctx->device->pipeline_add_f16_f32_f16;
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
         }
         }
         return nullptr;
         return nullptr;
     case GGML_OP_MUL:
     case GGML_OP_MUL:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_mul_f32;
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
         }
         }
         return nullptr;
         return nullptr;
     case GGML_OP_DIV:
     case GGML_OP_DIV:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_div_f32;
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
         }
         }
         return nullptr;
         return nullptr;
     case GGML_OP_CONCAT:
     case GGML_OP_CONCAT:

+ 7 - 2
ggml/src/vulkan-shaders/acc.comp

@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "types.comp"
 #include "generic_binary_head.comp"
 #include "generic_binary_head.comp"
 
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
 void main() {
     const uint idx = gl_GlobalInvocationID.x;
     const uint idx = gl_GlobalInvocationID.x;
     if (idx >= p.ne) {
     if (idx >= p.ne) {
@@ -15,10 +17,13 @@ void main() {
     const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
     const uint oy = (src1_i - (oz * p.nb02)) / p.nb01;
     const uint ox = src1_i % p.nb01;
     const uint ox = src1_i % p.nb01;
 
 
+    uint i00, i01, i02, i03;
+    get_indices(idx, i00, i01, i02, i03);
+
     if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
     if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
-        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
     } else {
     } else {
-        data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]));
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]));
     }
     }
 }
 }
 
 

+ 20 - 5
ggml/src/vulkan-shaders/add.comp

@@ -1,14 +1,29 @@
 #version 450
 #version 450
 
 
+#extension GL_EXT_shader_16bit_storage : require
+
 #include "types.comp"
 #include "types.comp"
 #include "generic_binary_head.comp"
 #include "generic_binary_head.comp"
 
 
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
 void main() {
-    const uint idx = get_idx();
+    uint idx = get_idx();
 
 
-    if (idx >= p.ne) {
-        return;
-    }
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
 
 
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) + FLOAT_TYPE(data_b[src1_idx(idx)]));
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
 }
 }

+ 2 - 0
ggml/src/vulkan-shaders/concat.comp

@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "types.comp"
 #include "generic_binary_head.comp"
 #include "generic_binary_head.comp"
 
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
 void main() {
     const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
     const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
     const int dim = p.param3;
     const int dim = p.param3;

+ 18 - 5
ggml/src/vulkan-shaders/div.comp

@@ -3,12 +3,25 @@
 #include "types.comp"
 #include "types.comp"
 #include "generic_binary_head.comp"
 #include "generic_binary_head.comp"
 
 
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
 void main() {
-    const uint idx = get_idx();
+    uint idx = get_idx();
 
 
-    if (idx >= p.ne) {
-        return;
-    }
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
 
 
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) / FLOAT_TYPE(data_b[src1_idx(idx)]));
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
 }
 }

+ 32 - 24
ggml/src/vulkan-shaders/generic_binary_head.comp

@@ -1,4 +1,5 @@
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : require
 
 
 layout (push_constant) uniform parameter
 layout (push_constant) uniform parameter
 {
 {
@@ -10,43 +11,50 @@ layout (push_constant) uniform parameter
     float param1; float param2; int param3;
     float param1; float param2; int param3;
 } p;
 } p;
 
 
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
-
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
 
+// true if src0/src1 are the same shape and the indices can be reused without additional modulus
+layout(constant_id = 0) const bool norepeat = false;
+
 uint get_idx() {
 uint get_idx() {
     return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
     return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
 }
 }
 
 
-uint src0_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
-    const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
-    const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
-    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
+uint fastmod(uint a, uint b) {
+    if ((b & (b-1)) == 0) {
+        return a & (b-1);
+    }
+    return a % b;
 }
 }
 
 
-uint src1_idx(uint idx) {
-    const uint i03 = idx / (p.ne02*p.ne01*p.ne00);
+uint fastdiv(uint a, uint b) {
+    return (a < b) ? 0 : (a / b);
+}
+
+void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
+    i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
     const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
     const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
-    const uint i02 = (idx - i03_offset) / (p.ne01*p.ne00);
+    i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
     const uint i02_offset = i02*p.ne01*p.ne00;
     const uint i02_offset = i02*p.ne01*p.ne00;
-    const uint i01 = (idx - i03_offset - i02_offset) / p.ne00;
-    const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+    i01 = (idx - i03_offset - i02_offset) / p.ne00;
+    i00 = idx - i03_offset - i02_offset - i01*p.ne00;
+}
+
+uint src0_idx(uint i00, uint i01, uint i02, uint i03) {
+    return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00;
+}
 
 
-    return (i03 % p.ne13)*p.nb13 + (i02 % p.ne12)*p.nb12 + (i01 % p.ne11)*p.nb11 + (i00 % p.ne10)*p.nb10;
+uint src1_idx(uint i00, uint i01, uint i02, uint i03) {
+    if (norepeat) {
+        return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10;
+    } else {
+        return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10;
+    }
 }
 }
 
 
-uint dst_idx(uint idx) {
-    const uint i23 = idx / (p.ne22*p.ne21*p.ne20);
-    const uint i23_offset = i23 * p.ne22*p.ne21*p.ne20;
-    const uint i22 = (idx - i23_offset) / (p.ne21*p.ne20);
-    const uint i22_offset = i22*p.ne21*p.ne20;
-    const uint i21 = (idx - i23_offset - i22_offset) / p.ne20;
-    const uint i20 = idx - i23_offset - i22_offset - i21*p.ne20;
-    return i23*p.nb23 + i22*p.nb22 + i21*p.nb21 + i20*p.nb20;
+uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
+    return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20;
 }
 }

+ 2 - 0
ggml/src/vulkan-shaders/get_rows.comp

@@ -3,6 +3,8 @@
 #include "types.comp"
 #include "types.comp"
 #include "generic_binary_head.comp"
 #include "generic_binary_head.comp"
 
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
 void main() {
     const uint i00 = gl_GlobalInvocationID.x;
     const uint i00 = gl_GlobalInvocationID.x;
     const uint i10 = gl_GlobalInvocationID.y;
     const uint i10 = gl_GlobalInvocationID.y;

+ 2 - 0
ggml/src/vulkan-shaders/get_rows_quant.comp

@@ -4,6 +4,8 @@
 #include "generic_binary_head.comp"
 #include "generic_binary_head.comp"
 #include "dequant_funcs.comp"
 #include "dequant_funcs.comp"
 
 
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
 void main() {
     const uint i00 = (gl_GlobalInvocationID.x)*2;
     const uint i00 = (gl_GlobalInvocationID.x)*2;
     const uint i10 = gl_GlobalInvocationID.y;
     const uint i10 = gl_GlobalInvocationID.y;

+ 18 - 5
ggml/src/vulkan-shaders/mul.comp

@@ -3,12 +3,25 @@
 #include "types.comp"
 #include "types.comp"
 #include "generic_binary_head.comp"
 #include "generic_binary_head.comp"
 
 
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
 void main() {
 void main() {
-    const uint idx = get_idx();
+    uint idx = get_idx();
 
 
-    if (idx >= p.ne) {
-        return;
-    }
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
 
 
-    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(idx)]) * FLOAT_TYPE(data_b[src1_idx(idx)]));
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
 }
 }