@@ -874,6 +874,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
+ __syncthreads();
+
// Write back combined meta data:
#pragma unroll
for (int imeta = 0; imeta < nmeta; ++imeta) {
@@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
KQ[j*D + tid] = -HALF_MAX_HALF;
half2 VKQ[ncols] = {{0.0f, 0.0f}};