Przeglądaj źródła

vulkan : mul_mat: fix UB with small warps (ggml/952)

When the device's warp size is less than 16,
it is possible for loadstride_a (mul_mm.comp:114)
and loadstride_b (mul_mm.comp:115) to be set to 0.
Because they are calculated as: the workgroup size,
multiplied by LOAD_VEC_* (which can be 1) and divided by 16.
And the workgroup size is set to be the same as the
warp/subgroup size.

The loadstride_* variables are used as increments in the
loops that populate the buffers used for the multiplication.

When they are 0 they cause an infinite loop.
But infinite loops without side-effects are UB and the
values of loadstride_* are known at compile time.
So, the compiler quietly optimizes all the loops away.
As a consequence, the buffers are not populated and
the multiplication result is just a matrix with all elements
set to 0.

We prevent the UB by making sure that the workgroup size
will never be less than 16, even if our device has a
smaller warp size (e.g. 8).

Signed-off-by: Salvatore Mesoraca <s.mesoraca16@gmail.com>
Salvatore Mesoraca 1 rok temu
rodzic
commit
cb00020504
1 zmienionych plików z 2 dodań i 2 usunięć
  1. 2 2
      ggml/src/ggml-vulkan.cpp

+ 2 - 2
ggml/src/ggml-vulkan.cpp

@@ -1164,11 +1164,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // mulmat
     // mulmat
     std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_m = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_m = { 128,  64,  64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
-    std::initializer_list<uint32_t> warptile_s = { device->subgroup_size,  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_s = { std::max(device->subgroup_size, 16u),  32,  32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
 
 
     std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_mmq_m = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
     std::initializer_list<uint32_t> warptile_mmq_m = { 128,  64,  64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
-    std::initializer_list<uint32_t> warptile_mmq_s = { device->subgroup_size,  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
+    std::initializer_list<uint32_t> warptile_mmq_s = { std::max(device->subgroup_size, 16u),  32,  32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
 
 
     std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
     std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
     std::array<uint32_t, 3> m_wg_denoms = { 64,  64, 1 };
     std::array<uint32_t, 3> m_wg_denoms = { 64,  64, 1 };