@@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
         const int ne13,
         const int ne31,
         const int ne32,
+        const int ne33,
         const int nb31,
         const int nb32,
+        const int nb33,
         const int nb01,
         const int nb02,
         const int nb03,
@@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
-        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
     constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
@@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int channel = kbc0 / (iter_k*iter_j);
-    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
+    const int head     = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
 
     if (jt*ncols1 + j >= ne01) {
         return;
     }
 
-    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+        const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
         const float2 * __restrict__ VKQ_meta,
         float * __restrict__ dst,
         const int parallel_blocks) {
-    VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
-    VKQ_meta  += parallel_blocks   * gridDim.z*blockIdx.x;
-    dst       += D                 * gridDim.z*blockIdx.x;
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    const int ne01 = gridDim.x;
+    const int ne02 = gridDim.y;
+
+    const int col      = blockIdx.x;
+    const int head     = blockIdx.y;
+    const int sequence = blockIdx.z;
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled * D;
 
     const int tid = threadIdx.x;
     __builtin_assume(tid < D);
 
     extern __shared__ float2 meta[];
     for (int i = tid; i < 2*parallel_blocks; i += D) {
-        ((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
     }
 
     __syncthreads();
@@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
         const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
         *((uint32_t *) &KQ_max_scale) &= ftz_mask;
 
-        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
         VKQ_denominator += KQ_max_scale * meta[l].y;
     }
 
-    dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
+    dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
 [[noreturn]]
@@ -705,8 +723,6 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
-    GGML_ASSERT(Q->ne[3] == 1);
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -853,8 +869,8 @@ void launch_fattn(
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
         nb11, nb12, nb13,
         nb21, nb22, nb23,
@@ -869,11 +885,11 @@ void launch_fattn(
 
             flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
-                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
         const dim3 block_dim_combine(DV, 1, 1);
-        const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
+        const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
         flash_attn_combine_results<DV>
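
// --- Illustration (not part of the patch) ---------------------------------------------------
// A minimal host-side sketch of the two index calculations this change introduces.
// decompose_kbc and j_dst_unrolled below are hypothetical helpers written only to make the
// arithmetic easy to check on the CPU; the variable names (iter_k, iter_j, ne01, ne02, ne03,
// ncols2) mirror the kernel parameters in the hunks above.
#include <cstdio>

struct kbc_parts { int sequence, head, jt; };

// Split a flattened stream-k work index into (sequence, head group, j-tile), as done in
// flash_attn_stream_k_fixup after this change (previously it was only (channel, j-tile)).
static kbc_parts decompose_kbc(const int kbc, const int iter_k, const int iter_j,
                               const int ne02, const int ncols2) {
    const int groups   = ne02/ncols2;                               // head groups per sequence
    const int sequence = kbc / (iter_k*iter_j*groups);              // index into Q->ne[3]
    const int head     = (kbc - iter_k*iter_j*groups*sequence) / (iter_k*iter_j);
    const int jt       = (kbc - iter_k*iter_j*groups*sequence - iter_k*iter_j*head) / iter_k;
    return {sequence, head, jt};
}

// flash_attn_combine_results addresses dst with the permuted [0, 2, 1, 3] layout,
// i.e. output rows ordered as (sequence, column, head):
static int j_dst_unrolled(const int sequence, const int col, const int head,
                          const int ne01, const int ne02) {
    return (sequence*ne01 + col)*ne02 + head;
}

int main() {
    // Toy sizes: 2 sequences, 8 heads grouped by 2, 3 j-tiles, 4 k-tiles, 5 CUDA blocks.
    const int ne03 = 2, ne02 = 8, ncols2 = 2, iter_j = 3, iter_k = 4, gridDim_x = 5;
    const int total = iter_k*iter_j*(ne02/ncols2)*ne03;             // total k-tiles, as in the patch
    for (int bidx0 = 0; bidx0 < gridDim_x; ++bidx0) {
        const int kbc0      = (bidx0 + 0)*total / gridDim_x;        // first k-tile of this block
        const int kbc0_stop = (bidx0 + 1)*total / gridDim_x;        // one past its last k-tile
        const kbc_parts p = decompose_kbc(kbc0, iter_k, iter_j, ne02, ncols2);
        std::printf("block %d: kbc [%2d, %2d) -> sequence %d, head group %d, j-tile %d\n",
                    bidx0, kbc0, kbc0_stop, p.sequence, p.head, p.jt);
    }
    std::printf("dst row for (sequence=1, col=0, head=3): %d\n", j_dst_unrolled(1, 0, 3, 7, 8));
    return 0;
}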