|
@@ -168,6 +168,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
for (int j = 0; j < ncols; ++j) {
|
|
for (int j = 0; j < ncols; ++j) {
|
|
|
KQ[j*D + tid] = -HALF_MAX_HALF;
|
|
KQ[j*D + tid] = -HALF_MAX_HALF;
|
|
|
}
|
|
}
|
|
|
|
|
+ __syncthreads();
|
|
|
|
|
|
|
|
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
|
half2 VKQ[ncols] = {{0.0f, 0.0f}};
|
|
|
|
|
|