@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-
-template<int D, int ncols, int KQ_stride> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
-    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
-
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+    constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
+    const int j     = blockIdx.y;
+    const int c     = blockIdx.z;
+    const int jc    = j*ncols2 + c;
+    const int tid   = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
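
Note on the hunk above: the single ncols template parameter is split into ncols1 (Q rows per tile) and ncols2 (heads handled together per tile), with ncols reconstructed as their product, and the fixup kernel now selects one row per thread block via blockIdx.y/blockIdx.z instead of an unrolled loop. The stream-k partition itself keeps its shape: iter_k*iter_j*(ne02/ncols2) KQ block-columns are split as evenly as integer arithmetic allows over gridDim.x persistent blocks. A minimal host-side sketch of that split, with made-up sizes (not part of the patch):

#include <cstdio>

int main() {
    // Assumed shapes, for illustration only:
    const int iter_k = 16, iter_j = 4, ne02 = 32, ncols2 = 2;
    const int grid_x = 7; // number of persistent blocks
    const int total  = iter_k*iter_j*(ne02/ncols2); // KQ block-columns to distribute

    for (int bidx0 = 0; bidx0 < grid_x; ++bidx0) {
        const int kbc0      = (bidx0 + 0)*total / grid_x; // first block-column (inclusive)
        const int kbc0_stop = (bidx0 + 1)*total / grid_x; // last block-column (exclusive)
        std::printf("block %d works on [%4d, %4d)\n", bidx0, kbc0, kbc0_stop);
    }
    return 0;
}

Adjacent blocks meet mid-tile at most once per boundary; those fractional tiles are exactly what the fixup kernel repairs.
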
@@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
 
-    dst += jt*ncols*ne02*D + channel*D;
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
 
-    // Load the partial result that needs a fixup:
-    float dst_val[ncols] = {0.0f};
-    float max_val[ncols] = {0.0f};
-    float rowsum[ncols] = {0.0f};
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            break;
-        }
-        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
 
-        const float2 tmp = dst_fixup[bidx0*ncols + j];
-        max_val[j] = tmp.x;
-        rowsum[j] = tmp.y;
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
     }
 
     // Iterate over previous blocks and compute the combined results.
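
The rewritten dst offset above fuses the logical (Q row, head, element) coordinates into one expression: jt*ncols1 + j is the global row, channel*ncols2 + c the global head, and dst is laid out as [ne01][ne02][D] floats. A standalone check that the fused form matches that reading, using assumed sizes and indices (hypothetical, for illustration only):

#include <cstdio>

int main() {
    // Assumed sizes and indices, for illustration only:
    const int D = 64, ne02 = 8, ncols1 = 4, ncols2 = 2;
    const int jt = 3, channel = 1, j = 2, c = 1, tid = 5;

    // Offset exactly as written in the kernel:
    const int fused = jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;

    // The same offset derived from logical (row, head, element) coordinates,
    // with dst laid out as [ne01 rows][ne02 heads][D elements]:
    const int row  = jt*ncols1 + j;      // global Q row
    const int head = channel*ncols2 + c; // global head index
    const int ref  = (row*ne02 + head)*D + tid;

    std::printf("fused=%d ref=%d match=%s\n", fused, ref, fused == ref ? "yes" : "no");
    return 0;
}
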
@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
             continue;
         }
 
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            if (jt*ncols + j >= ne01) {
-                break;
-            }
-            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
 
-            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
 
-            // Scale the current and new value accumulators depending on the max. values.
-            const float max_val_new = fmaxf(max_val[j], tmp.x);
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
 
-            const float diff_val = max_val[j] - max_val_new;
-            const float diff_add = tmp.x      - max_val_new;
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x   - max_val_new;
 
-            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
-            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
 
-            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
-            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
 
-            max_val[j] = max_val_new;
-        }
+        max_val = max_val_new;
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
         if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
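
The merge loop above is the standard online-softmax combination of two partial results with different running maxima: both the numerator (dst_val) and the denominator (rowsum) are rescaled by exp(old_max - new_max) before being summed, with underflowing factors flushed to zero. A scalar sketch of the same arithmetic with assumed values; SOFTMAX_FTZ_THRESHOLD is defined elsewhere in these headers, so a placeholder is used here:

#include <cmath>
#include <cstdio>

int main() {
    const float SOFTMAX_FTZ_THRESHOLD = -20.0f; // placeholder, see the real headers

    // Current partial accumulators for one row (assumed values):
    float dst_val = 0.7f, max_val = 1.5f, rowsum = 2.0f;
    // Partial result from a previous stream-k block (assumed values):
    const float dst_add = 0.3f, max_add = 2.5f, rowsum_add = 1.0f;

    // Same math as the scale_val/scale_add block in the kernel:
    const float max_val_new = std::fmax(max_val, max_add);
    const float diff_val = max_val - max_val_new;
    const float diff_add = max_add - max_val_new;

    // Rescale both sides into the frame of the shared maximum; tiny factors flush to zero.
    const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? std::exp(diff_val) : 0.0f;
    const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? std::exp(diff_add) : 0.0f;

    dst_val = scale_val*dst_val + scale_add*dst_add;
    rowsum  = scale_val*rowsum  + scale_add*rowsum_add;

    std::printf("merged numerator %.6f, denominator %.6f, output %.6f\n",
                dst_val, rowsum, dst_val/rowsum);
    return 0;
}
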
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 
     // Write back final result:
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            return;
-        }
-        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
-    }
+    *dst = dst_val / rowsum;
 }
 
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif // __clang__
-
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 // parallel_blocks == 0 is stream-k decomposition
-template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
         const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -763,25 +747,26 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
-    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
-        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = 2*nsm;
+        const int nblocks_stream_k = max_blocks;
 
-        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
         blocks_num.y = Q->ne[2];
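
The reworked heuristic above quantizes ntiles_total into waves of max_blocks = 2*nsm thread blocks and estimates utilization as work divided by capacity; stream-k is used when utilization would fall below 75 percent, or unconditionally from Ada Lovelace on (per the reordered use_stream_k condition). A sketch of the arithmetic with assumed numbers:

#include <cstdio>

int main() {
    const int nsm          = 108; // SM count, e.g. an A100; assumed
    const int ntiles_total = 300; // assumed tile count

    const int max_blocks   = 2*nsm;
    const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
    const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);

    std::printf("%d waves, %d%% projected utilization -> stream-k %s\n",
                tiles_nwaves, tiles_efficiency_percent,
                tiles_efficiency_percent < 75 ? "on" : "off (unless Ada or newer)");
    return 0;
}

With these numbers the last wave is mostly idle (69 percent utilization), so the stream-k path wins despite its fixup cost.
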
@@ -793,7 +778,6 @@ void launch_fattn(
         }
     }
 
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
@@ -832,9 +816,9 @@ void launch_fattn(
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
-            const dim3 blocks_num_combine = blocks_num;
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
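
On the launch geometry and sizing at the end: the fixup grid is now {blocks_num.x, ncols1, ncols2}, so blockIdx.y/blockIdx.z select the (j, c) row that each D-thread block repairs, matching the indices introduced in flash_attn_stream_k_fixup. The (2*2 + D) factor in the dst_tmp_meta allocation appears to cover two float2 metadata entries per block and column (the running max and rowsum read from the two halves of dst_fixup) plus D floats of partial output behind dst_fixup_data. A back-of-the-envelope sketch with assumed sizes:

#include <cstdio>

int main() {
    const int D = 128, ncols1 = 8, ncols2 = 2; // assumed head size and tile shape
    const int nblocks_stream_k = 216;          // assumed 2*nsm

    const int    ncols = ncols1*ncols2;
    // Per block and per column: 2 float2 of metadata (= 2*2 floats) + D partial outputs.
    const size_t bytes = (size_t) nblocks_stream_k*ncols * (2*2 + D) * sizeof(float);

    std::printf("dst_tmp_meta: %zu bytes\n", bytes);
    return 0;
}
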