瀏覽代碼

vulkan: fix diag_mask_inf (#11323)

With robustbufferaccess disabled, this shader was showing OOB stores. There
is a bounds check in the code, but the workgrouop dimensions were reversed vs
CUDA and it was running the wrong number of threads. So fix the workgroup
dimensions and disable robustness for this pipeline.
Jeff Bolz 1 年之前
父節點
當前提交
5245729e33
共有 2 個文件被更改,包括 2 次插入2 次删除
  1. 1 1
      ggml/src/ggml-vulkan/ggml-vulkan.cpp
  2. 1 1
      ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp

+ 1 - 1
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -2012,7 +2012,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);

+ 1 - 1
ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp

@@ -12,7 +12,7 @@ layout (push_constant) uniform parameter
 
 #include "types.comp"
 
-layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};