@@ -343,7 +343,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
             for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
                 const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne;
 
-                const half2 zero[cpy_ne] = {{0.0f, 0.0f}};
+                const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}};
                 ggml_cuda_memcpy_1<cpy_nb>(
                     tile_KV + i*(J/2 + J_padding) + j,
                     !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
@@ -394,11 +394,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile(
                 const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2);
 
                 const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}};
-                half2 tmp_h2[cpy_ne/2];
+                __align__(16) half2 tmp_h2[cpy_ne/2];
                 ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
                     tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
 
-                float2 tmp_f2[cpy_ne/2];
+                __align__(16) float2 tmp_f2[cpy_ne/2];
 #pragma unroll
                 for (int l = 0; l < cpy_ne/2; ++l) {
                     tmp_f2[l] = __half22float2(tmp_h2[l]);
@@ -445,14 +445,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ(
         static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
         for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
-            half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-            half2 Q_k[cpw][cpy_ne];
+            __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+            __align__(16) half2 Q_k[cpw][cpy_ne];
 #else
         static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
 #pragma unroll
         for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
-            float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
-            float Q_k[cpw][cpy_ne];
+            __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+            __align__(16) float Q_k[cpw][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -602,9 +602,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #pragma unroll
         for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
 #ifdef FAST_FP16_AVAILABLE
-            half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+            __align__(16) half tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #else
-            float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+            __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -664,8 +664,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #ifdef FAST_FP16_AVAILABLE
 #pragma unroll
     for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-        half2 V_k[(DVp/2)/warp_size];
-        half2 KQ_k[cpw];
+        __align__(16) half2 V_k[(DVp/2)/warp_size];
+        __align__(16) half2 KQ_k[cpw];
 
         constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
@@ -676,7 +676,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
             const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs);
 
-            half tmp[KQ_cs];
+            __align__(16) half tmp[KQ_cs];
             ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>(
                 &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs);
 #pragma unroll
@@ -696,8 +696,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
 #else
 #pragma unroll
     for (int k1 = 0; k1 < nbatch_V; k1 += np) {
-        float2 V_k[(DVp/2)/warp_size];
-        float KQ_k[cpw];
+        __align__(16) float2 V_k[(DVp/2)/warp_size];
+        __align__(16) float KQ_k[cpw];
 
         constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
@@ -821,12 +821,12 @@ static __global__ void flash_attn_tile(
     __shared__ half2 Q_tmp[ncols * DKQ/2];
     __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV];
     __shared__ half KQ[ncols * nbatch_fa];
-    half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #else
     __shared__ float Q_tmp[ncols * DKQ];
     __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV];
     __shared__ float KQ[ncols * nbatch_fa];
-    float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+    __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
 #endif // FAST_FP16_AVAILABLE
 
     float KQ_max[cpw];
@@ -849,7 +849,7 @@ static __global__ void flash_attn_tile(
 #pragma unroll
         for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
             if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) {
-                float tmp_f[cpy_ne_D] = {0.0f};
+                __align__(16) float tmp_f[cpy_ne_D] = {0.0f};
                 ggml_cuda_memcpy_1<sizeof(tmp_f)>
                     (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float))
                         + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]);
@@ -860,7 +860,7 @@ static __global__ void flash_attn_tile(
                 }
 
 #ifdef FAST_FP16_AVAILABLE
-                half2 tmp_h2[cpy_ne_D/2];
+                __align__(16) half2 tmp_h2[cpy_ne_D/2];
 #pragma unroll
                 for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
                     tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
@@ -959,7 +959,7 @@ static __global__ void flash_attn_tile(
         constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
 #pragma unroll
         for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-            half2 tmp[cpy_ne_D];
+            __align__(16) half2 tmp[cpy_ne_D];
             ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -970,7 +970,7 @@ static __global__ void flash_attn_tile(
         constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
 #pragma unroll
         for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
-            float tmp[cpy_ne_D];
+            __align__(16) float tmp[cpy_ne_D];
             ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]);
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
@@ -1033,7 +1033,7 @@ static __global__ void flash_attn_tile(
         constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
 #pragma unroll
         for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
-            float2 tmp[cpy_ne_D];
+            __align__(16) float2 tmp[cpy_ne_D];
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                 tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
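
Note: every `+` line above applies the same fix. Thread-local staging arrays that are passed to ggml_cuda_memcpy_1 get an explicit __align__(16), so the vectorized copy cannot land on an under-aligned address. The sketch below illustrates the failure mode and the fix, assuming ggml_cuda_memcpy_1 performs a 16-byte copy as a single vector access; the names copy16 and align_demo are made up for this example and are not part of the patch.

// Illustrative sketch only (hypothetical helper, not ggml code).
#include <cuda_fp16.h>

// One 16-byte copy issued as a single vectorized transaction.
// Both dst and src must be 16-byte aligned, otherwise the int4 access
// raises a misaligned-address error on the GPU.
static __device__ __forceinline__ void copy16(void * dst, const void * src) {
    *reinterpret_cast<int4 *>(dst) = *reinterpret_cast<const int4 *>(src);
}

__global__ void align_demo(const half2 * __restrict__ src, half2 * __restrict__ dst) {
    // A plain "half2 tmp[4]" is only guaranteed alignof(half2) == 4 bytes,
    // so using it as the destination of copy16 could trap. __align__(16)
    // makes the local array safe for the 16-byte vector load and store.
    // (In the real kernel the staged values feed the attention math;
    // here they are simply copied back out.)
    __align__(16) half2 tmp[4];
    copy16(tmp, src + 4*threadIdx.x);
    copy16(dst + 4*threadIdx.x, tmp);
}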