|
|
@@ -623,8 +623,8 @@ static __global__ void flash_attn_combine_results(
|
|
|
__builtin_assume(tid < D);
|
|
|
|
|
|
extern __shared__ float2 meta[];
|
|
|
- if (tid < 2*parallel_blocks) {
|
|
|
- ((float *) meta)[threadIdx.x] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + tid];
|
|
|
+ for (int i = tid; i < 2*parallel_blocks; i += D) {
|
|
|
+ ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
|
|
|
}
|
|
|
|
|
|
__syncthreads();
|