Browse Source

vulkan: workaround FA compile failures on macos (#13517)

Jeff Bolz 8 tháng trước cách đây
mục cha
commit
ab3971f2a0
1 tập tin đã thay đổi với 4 bổ sung3 xóa
  1. 4 3
      ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

+ 4 - 3
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

@@ -12,6 +12,7 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
+layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
 layout (constant_id = 1) const uint32_t Br = 1;
 layout (constant_id = 2) const uint32_t Bc = 32;
 layout (constant_id = 3) const uint32_t D = 32;
@@ -19,7 +20,7 @@ layout (constant_id = 3) const uint32_t D = 32;
 layout (constant_id = 5) const uint32_t D_split = 16;
 const uint32_t D_per_thread = D / D_split;
 
-const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split;
+const uint32_t cols_per_iter = WorkGroupSize / D_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
 
 layout (push_constant) uniform parameter {
@@ -134,8 +135,8 @@ ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const i
     return ACC_TYPE(pow(base, ACC_TYPE(exph)));
 }
 
-shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
-shared vec4 tmpshv4[gl_WorkGroupSize.x];
+shared FLOAT_TYPE tmpsh[WorkGroupSize];
+shared vec4 tmpshv4[WorkGroupSize];
 
 shared float masksh[Bc][Br];
 shared vec4 Qf[Br][D / 4];