@@ -85,19 +85,9 @@ static __global__ void flash_attn_ext_f16(
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);

- half slopeh = __float2half(1.0f);
- half2 slope2 = make_half2(1.0f, 1.0f);
-
- // ALiBi
- if (max_bias > 0.0f) {
- const uint32_t h = blockIdx.y;
-
- const float base = h < n_head_log2 ? m0 : m1;
- const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
- slopeh = __float2half(powf(base, exph));
- slope2 = make_half2(slopeh, slopeh);
- }
+ const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+ const half slopeh = __float2half(slopef);
+ const half2 slope2 = make_half2(slopef, slopef);

frag_b Q_b[D/16][ncols/frag_n];

@@ -439,108 +429,37 @@ static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");

-template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
- const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
- ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
- ggml_cuda_pool_alloc<float> dst_tmp(pool);
- ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
- if (parallel_blocks > 1) {
- dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
- dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
- }
-
- constexpr int frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
- const dim3 block_dim(WARP_SIZE, nwarps, 1);
- const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
- const int shmem = 0;
-
- float scale = 1.0f;
- float max_bias = 0.0f;
-
- memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
- memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
- const uint32_t n_head = Q->ne[2];
- const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
- const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
- const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
- flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
- <<<blocks_num, block_dim, shmem, main_stream>>> (
- (const char *) Q->data,
- (const char *) K->data,
- (const char *) V->data,
- mask ? ((const char *) mask->data) : nullptr,
- (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
- scale, max_bias, m0, m1, n_head_log2,
- Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
- K->ne[0], K->ne[1], K->ne[2], K->ne[3],
- mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
- Q->nb[1], Q->nb[2], Q->nb[3],
- K->nb[1], K->nb[2], K->nb[3],
- KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
- );
- CUDA_CHECK(cudaGetLastError());
-
- if ((parallel_blocks) == 1) {
- return;
- }
-
- const dim3 block_dim_combine(D, 1, 1);
- const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
- const int shmem_combine = 0;
-
- flash_attn_combine_results<D, parallel_blocks>
- <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
- (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
- CUDA_CHECK(cudaGetLastError());
-}
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
+void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * Q = dst->src[0];

-template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
- const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
- const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
+ constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+ const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

if (4*blocks_num_pb1 < 2*nsm) {
- launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+ constexpr int parallel_blocks = 4;
+ fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
- launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+ constexpr int parallel_blocks = 2;
+ fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
return;
}
- launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+ constexpr int parallel_blocks = 1;
+ fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
}

void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * Q = dst->src[0];
- const ggml_tensor * K = dst->src[1];
- const ggml_tensor * V = dst->src[2];
-
- const ggml_tensor * mask = dst->src[3];
-
- ggml_tensor * KQV = dst;
-
- GGML_ASSERT(Q->type == GGML_TYPE_F32);
- GGML_ASSERT(K->type == GGML_TYPE_F16);
- GGML_ASSERT(V->type == GGML_TYPE_F16);
- GGML_ASSERT(KQV->type == GGML_TYPE_F32);
-
- GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
- GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
- "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
-
- GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];

ggml_cuda_set_device(ctx.device);
-
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
- const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int32_t precision = KQV->op_params[2];

// On AMD the tile kernels perform poorly, use the vec kernel instead:
@@ -582,22 +501,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4;
switch (Q->ne[0]) {
case 64:
- launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
break;
case 80:
- launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
break;
case 96:
- launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
break;
case 112:
- launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
break;
case 128:
- launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
break;
case 256:
- launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
break;
default:
GGML_ASSERT(false);
@@ -608,22 +527,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4;
switch (Q->ne[0]) {
case 64:
- launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
break;
case 80:
- launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
break;
case 96:
- launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
break;
case 112:
- launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
break;
case 128:
- launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
break;
// case 256:
- // launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ // launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
// break;
default:
GGML_ASSERT(false);
@@ -643,16 +562,16 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4;
switch (Q->ne[0]) {
case 64:
- launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
break;
case 96:
- launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
break;
case 128:
- launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
break;
case 256:
- launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
break;
default:
GGML_ASSERT(false);
@@ -666,22 +585,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4;
switch (Q->ne[0]) {
case 64:
- launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
break;
case 80:
- launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
break;
case 96:
- launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
break;
case 112:
- launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
break;
case 128:
- launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
break;
case 256:
- launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
break;
default:
GGML_ASSERT(false);
@@ -694,22 +613,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
constexpr int nwarps = 4;
switch (Q->ne[0]) {
case 64:
- launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
break;
case 80:
- launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
break;
case 96:
- launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
break;
case 112:
- launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
break;
case 128:
- launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
break;
case 256:
- launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+ launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
break;
default:
GGML_ASSERT(false);