@@ -7,14 +7,19 @@
 #include "fattn-wmma-f16.cuh"
 
 #ifdef FP16_MMA_AVAILABLE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
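Both nvcuda::wmma (CUDA) and rocwmma (ROCm) expose the same fragment/fill/load/mma/store surface, which is why the single namespace alias above is enough to retarget the whole kernel. A minimal standalone sketch of that shared API, for illustration only (not part of this patch):

    #include <cuda_fp16.h>
    #include <mma.h>
    namespace wmma = nvcuda::wmma; // or: #include <rocwmma/rocwmma.hpp> + namespace wmma = rocwmma;

    // One warp computes a single 16x16 tile C = A*B (A row-major, B col-major, fp16 in, fp32 out).
    __global__ void tile_gemm_16x16x16(const half * A, const half * B, float * C) {
        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> b;
        wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 c;
        wmma::fill_fragment(c, 0.0f);
        wmma::load_matrix_sync(a, A, 16); // 16 = leading dimension
        wmma::load_matrix_sync(b, B, 16);
        wmma::mma_sync(c, a, b, c);
        wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major);
    }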
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -60,6 +65,8 @@ static __global__ void flash_attn_ext_f16(
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
 
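ggml_cuda_get_physical_warp_size() is what makes the kernel portable across 32-wide warps (NVIDIA, RDNA) and 64-wide wavefronts (GCN/CDNA); the old WARP_SIZE was a hard-coded 32. Because the helper is constexpr it is usable both in __launch_bounds__ above and for sizing stack arrays such as KQ_f_tmp below. Presumably along these lines (sketch; the actual helper lives in ggml's CUDA common header):

    static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
    #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
        return __AMDGCN_WAVEFRONT_SIZE; // 64 on GCN/CDNA, 32 on RDNA
    #else
        return 32;
    #endif
    }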
@@ -68,11 +75,11 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
 
     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -132,9 +139,9 @@ static __global__ void flash_attn_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+            if (i0 + warp_size > D/2 && i >= D/2) {
                 break;
             }
             VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
@@ -146,9 +153,9 @@ static __global__ void flash_attn_ext_f16(
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
@@ -162,7 +169,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -176,20 +183,20 @@ static __global__ void flash_attn_ext_f16(
             frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+                wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
             }
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                    wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
                 }
             }
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+                wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
             }
         }
 
@@ -202,27 +209,27 @@ static __global__ void flash_attn_ext_f16(
             const int j = j0 + threadIdx.y;
 
             if (std::is_same<KQ_acc_t, float>::value) {
-                float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+                float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+                    KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
 
                     if (use_logit_softcap) {
-                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                        KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
                     }
                 }
 
                 float KQ_max_new = KQ_max_f[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
-                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                    KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
                 }
-                KQ_max_new = warp_reduce_max(KQ_max_new);
+                KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
 
                 const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
                 KQ_max_scale_f[j0/nwarps] = expf(diff);
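Note that the reductions now take the warp width as a template argument: warp_reduce_max<warp_size> must shuffle across all 64 lanes of a CDNA wavefront, not just 32, or half of each row's partial maxima would be dropped. Similar in spirit to ggml's helper (sketch under that assumption, not the library's exact code):

    // Butterfly max-reduction across a warp of `width` lanes.
    template <int width>
    static __device__ float warp_max(float x) {
    #pragma unroll
        for (int offset = width/2; offset > 0; offset >>= 1) {
            x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
        }
        return x;
    }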
@@ -233,48 +240,48 @@ static __global__ void flash_attn_ext_f16(
 
                 float KQ_rowsum_add = 0.0f;
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
-                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/warp_size] = expf(diff);
                     if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                        KQ_f_tmp[k0/warp_size] = 0.0f;
                     }
-                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
-                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                    KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
                 }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
 
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
             } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+                    KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
 
                     if (use_logit_softcap) {
                         // There is no dedicated tangens hyperbolicus function for half2.
-                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
-                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
-                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+                        KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
+                                               /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
 
-                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                        KQ2_tmp[k0/warp_size] *= logit_softcap_2;
                     }
                 }
 
                 half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                    KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
                 }
-                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
                 const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
                 KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
                 const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
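Since there is no half2 tanh intrinsic, the softcap path above applies the identity tanh(x) = (e^(2x) - 1)/(e^(2x) + 1) via h2exp, handling both lanes of each half2 at once. The same identity in scalar fp32, for reference (illustration only, not patch code):

    __device__ float tanh_via_exp(float x) {
        const float e2x = expf(2.0f*x);      // e^(2x)
        return (e2x - 1.0f) / (e2x + 1.0f);  // == tanhf(x)
    }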
@@ -283,17 +290,17 @@ static __global__ void flash_attn_ext_f16(
 
                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
 
-                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
-                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/warp_size] = h2exp(diff);
                     const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
-                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
-                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                    *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/warp_size];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
                 }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
 
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
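The rescaling in both branches is the standard online-softmax bookkeeping: whenever a new KQ chunk raises the running row maximum, the previously accumulated sum (and, further down, the VKQ accumulators) is scaled by exp(m_old - m_new) so all terms stay normalized to the same maximum. Per row, in scalar form (sketch, not patch code):

    struct online_softmax_state { float m; float s; };

    // Fold one chunk (max m_c; sum s_c of its exp(x - m_new) terms) into the row state.
    __device__ void online_softmax_update(online_softmax_state & st, float m_c, float s_c) {
        const float m_new = fmaxf(st.m, m_c);
        st.s = expf(st.m - m_new)*st.s + s_c; // rescale the old sum, then add the chunk
        st.m = m_new;
    }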
@@ -308,7 +315,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -320,7 +327,7 @@ static __global__ void flash_attn_ext_f16(
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
             }
 
 #pragma unroll
@@ -328,10 +335,10 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -343,10 +350,10 @@ static __global__ void flash_attn_ext_f16(
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }
 
@@ -364,9 +371,9 @@ static __global__ void flash_attn_ext_f16(
             }
 
 #pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
                 const int i = i0 + threadIdx.x;
-                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+                if (i0 + warp_size > D/2 && i >= D/2) {
                     break;
                 }
 
@@ -398,9 +405,9 @@ static __global__ void flash_attn_ext_f16(
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             float dst_val = VKQ[j_VKQ*D_padded + i];
@@ -425,7 +432,7 @@ static __global__ void flash_attn_ext_f16(
     }
 #else
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
 
 constexpr int get_max_power_of_2(int x) {
@@ -515,6 +522,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * Q = dst->src[0];
 
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
 
     if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
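On the host side the warp width must come from the runtime rather than a compile-time constant, since one binary can drive both 32- and 64-wide devices; ggml caches it in its per-device info table. The portable query underneath is equivalent to (standalone illustration):

    #include <cuda_runtime.h> // or <hip/hip_runtime.h> with the hip* equivalents

    int get_device_warp_size(int device) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, device);
        return prop.warpSize; // 32 on NVIDIA; 32 or 64 on AMD depending on the GPU
    }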
@@ -571,7 +579,8 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         return;
     }
 
-    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+    if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
             case 64:
@@ -592,6 +601,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;