|
@@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
|
|
for (int j = 0; j < ncols; ++j) {
|
|
for (int j = 0; j < ncols; ++j) {
|
|
|
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
|
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
|
|
|
|
|
|
|
- kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
|
|
|
|
if (threadIdx.x == 0) {
|
|
if (threadIdx.x == 0) {
|
|
|
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
|
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
|
|
}
|
|
}
|