* kqmax_new_j in every thread within warp is same after operate at line 199,this reduce can be omit * same problem in vec32 --------- Co-authored-by: ZhaoXiaoYu <zhao.xiaoyu@zte.com.cn>
@@ -220,7 +220,6 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
- kqmax_new_j = warp_reduce_max(kqmax_new_j);
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
}
@@ -206,7 +206,6 @@ static __global__ void flash_attn_vec_ext_f32(
float kqmax_new_j = kqmax_new_arr[j];