Преглед изворни кода

vulkan: fix diag_mask_inf (#11323)

With robustbufferaccess disabled, this shader was showing OOB stores. There
is a bounds check in the code, but the workgrouop dimensions were reversed vs
CUDA and it was running the wrong number of threads. So fix the workgroup
dimensions and disable robustness for this pipeline.
Jeff Bolz пре 11 месеци
родитељ
комит
5245729e33

+ 1 - 1
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -2012,7 +2012,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);

+ 1 - 1
ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp

@@ -12,7 +12,7 @@ layout (push_constant) uniform parameter
 
 #include "types.comp"
 
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};