@@ -2,20 +2,30 @@
 #include "fattn-common.cuh"
 #include "fattn-tile.cuh"
 
-#define FATTN_TILE_NTHREADS 256
+// kq_stride == number of KQ rows to process per iteration
+// kq_nbatch == number of K columns to load in parallel for KQ calculation
 
 static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
     if (GGML_CUDA_CC_IS_AMD(cc)) {
+        if (GGML_CUDA_CC_IS_RDNA(cc)) {
+            switch (D) {
+                case 64:
+                    return 128;
+                case 128:
+                case 256:
+                    return ncols <= 16 ? 128 : 64;
+                default:
+                    GGML_ABORT("fatal error");
+                    return -1;
+            }
+        }
         switch (D) {
             case 64:
-                return 64;
+                return ncols == 32 ? 128 : 64;
            case 128:
+                return ncols == 32 ? 64 : 32;
            case 256:
-                if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
-                    return ncols <= 16 ? 64 : 32;
-                } else {
-                    return 64;
-                }
+                return 32;
            default:
                GGML_ABORT("fatal error");
                return -1;
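
The following is a standalone host-side sketch, not part of the patch, that mirrors the AMD branch of fattn_tile_get_kq_stride_host above; the GGML_CUDA_CC_IS_* checks are replaced by a plain is_rdna flag so the selected KQ row count per iteration can be printed for a few (D, ncols) combinations. The sketch_* names are made up for illustration.

#include <cstdio>
#include <initializer_list>

// Sketch only: copies the switch logic of the AMD branch shown in the hunk above.
static int sketch_kq_stride_amd(const int D, const int ncols, const bool is_rdna) {
    if (is_rdna) {
        switch (D) {
            case 64:  return 128;
            case 128:
            case 256: return ncols <= 16 ? 128 : 64;
            default:  return -1;
        }
    }
    switch (D) {
        case 64:  return ncols == 32 ? 128 : 64;
        case 128: return ncols == 32 ? 64 : 32;
        case 256: return 32;
        default:  return -1;
    }
}

int main() {
    for (int D : {64, 128, 256}) {
        for (int ncols : {16, 32}) {
            printf("D=%3d ncols=%2d -> kq_stride RDNA=%3d GCN/CDNA=%3d\n",
                   D, ncols, sketch_kq_stride_amd(D, ncols, true), sketch_kq_stride_amd(D, ncols, false));
        }
    }
    return 0;
}
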
@@ -49,24 +59,28 @@ static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int
 
 static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
 #ifdef GGML_USE_HIP
+#ifdef RDNA
     switch (D) {
        case 64:
-            return 64;
+            return 128;
        case 128:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 32;
-#else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
        case 256:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 32;
+            return ncols <= 16 ? 128 : 64;
+        default:
+            return -1;
+    }
 #else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
+    switch (D) {
+        case 64:
+            return ncols == 32 ? 128 : 64;
+        case 128:
+            return ncols == 32 ? 64 : 32;
+        case 256:
+            return 32;
        default:
            return -1;
    }
+#endif // RDNA
 #else
 #ifdef FAST_FP16_AVAILABLE
    switch (D) {
@@ -100,17 +114,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
        case 64:
            return 64;
        case 128:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 128;
-#else
-            return 64;
-#endif // defined(GCN) || defined(CDNA)
        case 256:
-#if defined(GCN) || defined(CDNA)
-            return ncols <= 16 ? 64 : 128;
-#else
-            return ncols <= 16 ? 64 : 256;
-#endif // defined(GCN) || defined(CDNA)
+            return 128;
        default:
            return -1;
    }
@@ -120,9 +125,8 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
        case 64:
            return 64;
        case 128:
-            return ncols <= 16 ? 128 : 64;
        case 256:
-            return ncols <= 16 ? 64 : 128;
+            return 128;
        default:
            return -1;
    }
@@ -142,12 +146,27 @@ static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols
     GGML_UNUSED_VARS(ncols, warp_size);
 }
 
-template<int D, int ncols, bool use_logit_softcap> // D == head size
-#ifdef GGML_USE_HIP
-__launch_bounds__(FATTN_TILE_NTHREADS, 1)
+static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
+    return 256;
+    GGML_UNUSED_VARS(cc, ncols);
+}
+
+static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
+    return 256;
+    GGML_UNUSED(ncols);
+}
+
+static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
+#ifdef RDNA
+    return 3;
 #else
-__launch_bounds__(FATTN_TILE_NTHREADS, 2)
-#endif // GGML_USE_HIP
+    return ncols <= 16 ? 3 : 2;
+#endif // RDNA
+    GGML_UNUSED(ncols);
+}
+
+template<int D, int ncols, bool use_logit_softcap> // D == head size
+__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
 static __global__ void flash_attn_tile(
        const char * __restrict__ Q,
        const char * __restrict__ K,
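
A minimal standalone CUDA sketch, not part of the patch, of the mechanism that replaces the FATTN_TILE_NTHREADS macro: __launch_bounds__ accepts integral constant expressions, so constexpr helpers of the ncols template parameter can supply the thread count and the min-blocks-per-SM hint per instantiation. The sketch_* names are invented; the values mirror fattn_tile_get_nthreads_device and the non-RDNA branch of fattn_tile_get_occupancy_device.

#include <cuda_runtime.h>

static constexpr __host__ __device__ int sketch_nthreads(int /*ncols*/) {
    return 256;                     // same thread count for every ncols, as in the patch
}
static constexpr __host__ __device__ int sketch_occupancy(int ncols) {
    return ncols <= 16 ? 3 : 2;     // occupancy hint depends on the compile-time column count
}

template <int ncols>
__launch_bounds__(sketch_nthreads(ncols), sketch_occupancy(ncols))
static __global__ void sketch_kernel(int * out) {
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        *out = ncols;
    }
}

int main() {
    int * out = nullptr;
    cudaMalloc(&out, sizeof(int));
    sketch_kernel<16><<<1, sketch_nthreads(16)>>>(out); // 256 threads, min 3 blocks/SM hint
    sketch_kernel<32><<<1, sketch_nthreads(32)>>>(out); // 256 threads, min 2 blocks/SM hint
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}
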
@@ -193,7 +212,7 @@ static __global__ void flash_attn_tile(
     }
 
     constexpr int warp_size = 32;
-    constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size;
+    constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
     constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
     static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
     constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
@@ -206,90 +225,126 @@ static __global__ void flash_attn_tile(
     const int sequence = blockIdx.z / ne02;
     const int head = blockIdx.z - sequence*ne02;
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    const float2 * Q_f2  = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
-    const half   * maskh = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
-    const float  * sinksf = (const float  *) (sinks);
+    const float  * Q_f    = (const float  *) (Q + nb03* sequence + nb02* head + nb01*ic0);
+    const half2  * K_h2   = (const half2  *) (K + nb13* sequence + nb12*(head / gqa_ratio));
+    const half2  * V_h2   = (const half2  *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
+    const half   * maskh  = (const half   *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float  * sinksf = (const float  *) (sinks);
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
-#if defined(GGML_USE_HIP)
-    constexpr int cpy_nb = 16;
-#else
-    constexpr int cpy_nb = 8;
-#endif // defined(GGML_USE_HIP) && defined(GCN)
+    constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
     constexpr int cpy_ne = cpy_nb / 4;
 
-    __shared__ float KQ[ncols][kq_stride];
+    constexpr int cpw = ncols/nwarps; // cols per warp
+
+    // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
+    // KQ is originally 2D but uses a Z-shaped memory pattern for larger reads/writes.
 #ifdef FAST_FP16_AVAILABLE
+    constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
+
+    __shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
     __shared__ half2 Q_tmp[ncols][D/2];
-    __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
-    half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
+    __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
+    half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #else
+    constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
+
+    __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
     __shared__ float Q_tmp[ncols][D];
-    __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
-    float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
-    float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
+    __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
+    float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
 #endif // FAST_FP16_AVAILABLE
+    static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
 
-
-    float kqmax[ncols/nwarps];
+    float KQ_max[cpw];
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        kqmax[j0/nwarps] = -FLT_MAX/2.0f;
+        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
    }
-    float kqsum[ncols/nwarps] = {0.0f};
+    float KQ_sum[cpw] = {0.0f};
 
+    // Load Q data, convert to FP16 if fast.
 #pragma unroll
-    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
+    for (int j0 = 0; j0 < cpw; ++j0) {
+        const int j = j0 + threadIdx.y*cpw;
+
+        constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
 
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-            const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f);
+        for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+            float tmp_f[cpy_ne_D] = {0.0f};
+            if (ic0 + j < ne01) {
+                ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
+            }
+
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                tmp_f[i1] *= scale;
+            }
+
 #ifdef FAST_FP16_AVAILABLE
-            Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale);
+            half2 tmp_h2[cpy_ne_D/2];
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
+                tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
+            }
+            ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
 #else
-            Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale;
-            Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale;
+            ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
 #endif // FAST_FP16_AVAILABLE
        }
    }
 
     __syncthreads();
 
+    // Main loop over KV cache:
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
        // Calculate KQ tile and keep track of new maximum KQ values:
 
-        float kqmax_new[ncols/nwarps];
+        float KQ_max_new[cpw];
 #pragma unroll
-        for (int j = 0; j < ncols/nwarps; ++j) {
-            kqmax_new[j] = kqmax[j];
+        for (int j = 0; j < cpw; ++j) {
+            KQ_max_new[j] = KQ_max[j];
        }
 
-        float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}};
+        float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
 
+        // KQ = K @ Q matrix multiplication:
 #pragma unroll
        for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
 #pragma unroll
            for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
                const int i_KQ = i_KQ_0 + threadIdx.y;
 
-#pragma unroll
-                for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
-                    const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
 #ifdef FAST_FP16_AVAILABLE
-                    KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
+                constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
+#pragma unroll
+                for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
+                    ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
+                        &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
+                        &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
+                }
 #else
-                    const float2 tmp_f2 = __half22float2(tmp_h2);
-                    KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
-                    KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
-#endif // FAST_FP16_AVAILABLE
+                constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
+#pragma unroll
+                for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
+                    half2 tmp_h2[cpy_ne_kqnb/2];
+                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
+                        tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
+
+                    float2 tmp_f2[cpy_ne_kqnb/2];
+#pragma unroll
+                    for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
+                        tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
+                    }
+                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
+                        &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
                }
+#endif // FAST_FP16_AVAILABLE
            }
 
            __syncthreads();
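
A standalone sketch, not part of the patch, of the "Z-shaped" KQ layout introduced above: the logical ncols x kq_stride score matrix is stored as [ncols/softmax_iter_j][kq_stride][softmax_iter_j], so for a fixed row i the scores of softmax_iter_j neighbouring columns j are contiguous and can be moved with a single wide copy (ggml_cuda_memcpy_1 in the kernel). flat_index_z is a hypothetical helper written only for illustration.

#include <cassert>

// Flat offset of logical element (column j, row i) in the 3D Z-shaped buffer.
constexpr int flat_index_z(int j, int i, int kq_stride, int softmax_iter_j) {
    return (j/softmax_iter_j)*(kq_stride*softmax_iter_j) + i*softmax_iter_j + j % softmax_iter_j;
}

int main() {
    constexpr int ncols = 32, kq_stride = 64, softmax_iter_j = 4;
    // For a fixed row i, the softmax_iter_j columns j0..j0+softmax_iter_j-1 are adjacent in memory:
    for (int j0 = 0; j0 < ncols; j0 += softmax_iter_j) {
        for (int i = 0; i < kq_stride; ++i) {
            for (int j1 = 1; j1 < softmax_iter_j; ++j1) {
                assert(flat_index_z(j0 + j1,     i, kq_stride, softmax_iter_j) ==
                       flat_index_z(j0 + j1 - 1, i, kq_stride, softmax_iter_j) + 1);
            }
        }
    }
    return 0;
}
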
@@ -298,12 +353,12 @@ static __global__ void flash_attn_tile(
 #pragma unroll
            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
                half2 K_k[kq_stride/warp_size][cpy_ne];
-                half2 Q_k[ncols/nwarps][cpy_ne];
+                half2 Q_k[cpw][cpy_ne];
 #else
 #pragma unroll
            for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
                float K_k[kq_stride/warp_size][cpy_ne];
-                float Q_k[ncols/nwarps][cpy_ne];
+                float Q_k[cpw][cpy_ne];
 #endif // FAST_FP16_AVAILABLE
 
 #pragma unroll
@@ -311,29 +366,29 @@ static __global__ void flash_attn_tile(
                    const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #ifdef FAST_FP16_AVAILABLE
-                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
+                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
 #else
-                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch   + cpy_ne) + k_KQ_1]);
+                    ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
                }
 #pragma unroll
-                for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-                    const int j_KQ = j_KQ_0 + threadIdx.y;
+                for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
+                    const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
 
 #ifdef FAST_FP16_AVAILABLE
-                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
+                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
 #else
-                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0   + k_KQ_1]);
+                    ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
 #endif // FAST_FP16_AVAILABLE
                }
 
 #pragma unroll
                for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
-                    for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
+                    for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
 #pragma unroll
                        for (int k = 0; k < cpy_ne; ++k) {
-                            ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
+                            ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
                        }
                    }
                }
@@ -344,104 +399,77 @@ static __global__ void flash_attn_tile(
            }
        }
 
+        // Apply logit softcap, mask, update KQ_max:
 #pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
            const int i_KQ = i_KQ_0 + threadIdx.x;
 
 #pragma unroll
-            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-                const int j_KQ = j_KQ_0 + threadIdx.y;
+            for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
+                const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
 
                if (use_logit_softcap) {
-                    sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
+                    KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
                }
 
-                sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
-
-                kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
+                KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
 
-                KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps];
+                KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
            }
        }
 
        __syncthreads();
 
+        // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            const int j = j0 + threadIdx.y;
-
-            kqmax_new[j0/nwarps] = warp_reduce_max<warp_size>(kqmax_new[j0/nwarps]);
-            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
-            kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
-
-            float kqsum_add = 0.0f;
-            if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
-#pragma unroll
-                for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
-                    const int i = i0 + 4*threadIdx.x;
-
-                    float4 val = *(const float4 *) &KQ[j][i];
-                    val.x = expf(val.x - kqmax[j0/nwarps]);
-                    val.y = expf(val.y - kqmax[j0/nwarps]);
-                    val.z = expf(val.z - kqmax[j0/nwarps]);
-                    val.w = expf(val.w - kqmax[j0/nwarps]);
-                    kqsum_add += val.x + val.y + val.z + val.w;
-
+        for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
 #ifdef FAST_FP16_AVAILABLE
-                    const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
-                    ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
+            half tmp[kq_stride/warp_size][softmax_iter_j];
 #else
-                    ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
+            float tmp[kq_stride/warp_size][softmax_iter_j];
 #endif // FAST_FP16_AVAILABLE
-                }
-            } else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
+
 #pragma unroll
-                for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
-                    const int i = i0 + 2*threadIdx.x;
+            for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
+                KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
+                const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
+                KQ_max[j0+j1] = KQ_max_new[j0+j1];
 
-                    float2 val = *(const float2 *) &KQ[j][i];
-                    val.x = expf(val.x - kqmax[j0/nwarps]);
-                    val.y = expf(val.y - kqmax[j0/nwarps]);
-                    kqsum_add += val.x + val.y;
-#ifdef FAST_FP16_AVAILABLE
-                    const half2 tmp = make_half2(val.x, val.y);
-                    ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
-#else
-                    ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
-#endif // FAST_FP16_AVAILABLE
-                }
-            } else {
+                float KQ_sum_add = 0.0f;
+#pragma unroll
                for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
-                    const int i = i0 + threadIdx.x;
+                    const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
+                    KQ_sum_add += val;
+                    tmp[i0/warp_size][j1] = val;
+                }
+                KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
 
-                    const float diff = KQ[j][i] - kqmax[j0/nwarps];
-                    const float val = expf(diff);
-                    kqsum_add += val;
 #ifdef FAST_FP16_AVAILABLE
-                    ((half *) KQ[j])[i] = val;
+                const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
+                }
 #else
-                    KQ[j][i] = val;
-#endif // FAST_FP16_AVAILABLE
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+                    VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
+                    VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
                }
+#endif // FAST_FP16_AVAILABLE
            }
-            kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
 
-#ifdef FAST_FP16_AVAILABLE
-            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
-            }
-#else
-#pragma unroll
-            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
-                VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
+            for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
+                const int i = i0 + threadIdx.x;
+
+                ggml_cuda_memcpy_1<sizeof(tmp[0])>(
+                    KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
            }
-#endif // FAST_FP16_AVAILABLE
        }
 
-        constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D;
+        // VKQ = V @ KQ matrix multiplication:
+        constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
        static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
 #pragma unroll
        for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
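
A scalar, host-side sketch, not part of the patch, of the streaming-softmax update that the renamed KQ_max/KQ_sum/VKQ variables implement per column: when a new tile of scores arrives, the running maximum is raised, the running sum and the V-weighted accumulator are rescaled by exp(old_max - new_max), and the tile's contributions are added; dividing by the final sum (the gridDim.y == 1 path later in the kernel) gives the same result as a one-pass softmax. All names here are illustrative only.

#include <cmath>
#include <cstdio>
#include <vector>

struct OnlineSoftmaxCol {
    float KQ_max = -1e30f; // running maximum of the scores seen so far
    float KQ_sum = 0.0f;   // running sum of exp(score - KQ_max)
    float VKQ    = 0.0f;   // running accumulator (1D stand-in for a row of V*softmax)

    void add_tile(const std::vector<float> & s, const std::vector<float> & v) {
        float KQ_max_new = KQ_max;
        for (float si : s) {
            KQ_max_new = std::fmax(KQ_max_new, si);
        }
        const float KQ_max_scale = std::exp(KQ_max - KQ_max_new); // rescale the old state
        KQ_max = KQ_max_new;

        float KQ_sum_add = 0.0f;
        float VKQ_add    = 0.0f;
        for (size_t k = 0; k < s.size(); ++k) {
            const float val = std::exp(s[k] - KQ_max);
            KQ_sum_add += val;
            VKQ_add    += val*v[k];
        }
        KQ_sum = KQ_sum*KQ_max_scale + KQ_sum_add;
        VKQ    = VKQ   *KQ_max_scale + VKQ_add;
    }

    float result() const { return VKQ / KQ_sum; } // final normalization
};

int main() {
    OnlineSoftmaxCol c;
    c.add_tile({0.5f, 2.0f}, {1.0f, 2.0f});
    c.add_tile({3.0f, 1.0f}, {3.0f, 4.0f});
    printf("%f\n", c.result()); // matches a one-pass softmax over all four scores
    return 0;
}
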
@@ -449,65 +477,96 @@ static __global__ void flash_attn_tile(
            for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
                const int k_tile = k1 + threadIdx.y;
 
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                    const int i = i0 + threadIdx.x;
-
-                    const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i];
 #ifdef FAST_FP16_AVAILABLE
-                    KV_tmp_h2[k_tile*(D/2) + i] = tmp;
+                constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
+                    ggml_cuda_memcpy_1<cpy_ne_D*4>(
+                        &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
+                        &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
+                }
 #else
-                    KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp);
-#endif // FAST_FP16_AVAILABLE
+                constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+                    half2 tmp_h2[cpy_ne_D/2];
+                    ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
+                        tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
+
+                    float2 tmp_f2[cpy_ne_D/2];
+#pragma unroll
+                    for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+                        tmp_f2[i1] = __half22float2(tmp_h2[i1]);
+                    }
+                    ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
+                        &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
                }
+#endif // FAST_FP16_AVAILABLE
            }
 
            __syncthreads();
 
+#ifdef FAST_FP16_AVAILABLE
 #pragma unroll
            for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
-#ifdef FAST_FP16_AVAILABLE
                half2 V_k[(D/2)/warp_size];
-                half2 KQ_k[ncols/nwarps];
-#else
-                float2 V_k[(D/2)/warp_size];
-                float  KQ_k[ncols/nwarps];
-#endif // FAST_FP16_AVAILABLE
+                half2 KQ_k[cpw];
 
+                constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
 #pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                    const int i = i0 + threadIdx.x;
+                for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
+                    ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
+                }
+#pragma unroll
+                for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
+                    const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
 
-#ifdef FAST_FP16_AVAILABLE
-                    V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i];
+                    half tmp[softmax_iter_j];
+                    ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
+                        &tmp, KQ[j][k0 + k1]);
+#pragma unroll
+                    for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
+                        KQ_k[j0+j1] = __half2half2(tmp[j1]);
+                    }
+                }
+
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
+#pragma unroll
+                    for (int j0 = 0; j0 < cpw; ++j0) {
+                        VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
+                    }
+                }
+            }
 #else
-                    V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i];
-#endif // FAST_FP16_AVAILABLE
+#pragma unroll
+            for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
+                float2 V_k[(D/2)/warp_size];
+                float  KQ_k[cpw];
+
+                constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+                    ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
                }
 #pragma unroll
-                for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-                    const int j = j0 + threadIdx.y;
+                for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
+                    const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
 
-#ifdef FAST_FP16_AVAILABLE
-                    KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
-#else
-                    KQ_k[j0/nwarps] = KQ[j][k0 + k1];
-#endif // FAST_FP16_AVAILABLE
+                    ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
+                        &KQ_k[j0], KQ[j][k0 + k1]);
                }
 
 #pragma unroll
                for (int i0 = 0; i0 < D/2; i0 += warp_size) {
 #pragma unroll
-                    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-#ifdef FAST_FP16_AVAILABLE
-                        VKQ[j0/nwarps][i0/warp_size]   += V_k[i0/warp_size]  *KQ_k[j0/nwarps];
-#else
-                        VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps];
-                        VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps];
-#endif // FAST_FP16_AVAILABLE
+                    for (int j0 = 0; j0 < cpw; ++j0) {
+                        VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
+                        VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
                    }
                }
            }
+#endif // FAST_FP16_AVAILABLE
 
            __syncthreads();
        }
@@ -519,69 +578,92 @@ static __global__ void flash_attn_tile(
        const float sink = sinksf[head];
 
 #pragma unroll
-        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
-            float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
-            kqmax_new_j = warp_reduce_max<warp_size>(kqmax_new_j);
+        for (int j0 = 0; j0 < cpw; ++j0) {
+            float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
+            KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
 
-            const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
-            kqmax[j0/nwarps] = kqmax_new_j;
+            const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
+            KQ_max[j0] = KQ_max_new_j;
 
-            const float val = expf(sink - kqmax[j0/nwarps]);
-            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
+            const float val = expf(sink - KQ_max[j0]);
+            KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
            if (threadIdx.x == 0) {
-                kqsum[j0/nwarps] += val;
+                KQ_sum[j0] += val;
            }
 
 #ifdef FAST_FP16_AVAILABLE
            const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
+                VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
            }
 #else
 #pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += warp_size) {
-                VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
-                VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
+                VKQ[j0][i0/warp_size].x *= KQ_max_scale;
+                VKQ[j0][i0/warp_size].y *= KQ_max_scale;
            }
 #endif // FAST_FP16_AVAILABLE
        }
    }
 
-    float2 * dst2 = (float2 *) dst;
+#pragma unroll
+    for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
+        KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
+    }
+    if (gridDim.y == 1) {
+#pragma unroll
+        for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
+#ifdef FAST_FP16_AVAILABLE
+            const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
+#pragma unroll
+            for (int i = 0; i < (D/2)/warp_size; ++i) {
+                VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
+            }
+#else
+            const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
+#pragma unroll
+            for (int i = 0; i < (D/2)/warp_size; ++i) {
+                VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
+                VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
+            }
+#endif // FAST_FP16_AVAILABLE
+        }
+    }
 
+    // Write back results:
 #pragma unroll
-    for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
-        const int j_VKQ = j_VKQ_0 + threadIdx.y;
+    for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
+        const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
 
        if (ic0 + j_VKQ >= ne01) {
            return;
        }
 
-        float kqsum_j = kqsum[j_VKQ_0/nwarps];
-        kqsum_j = warp_reduce_sum<warp_size>(kqsum_j);
-
        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
 
-#pragma unroll
-        for (int i00 = 0; i00 < D/2; i00 += warp_size) {
-            const int i0 = i00 + threadIdx.x;
-
 #ifdef FAST_FP16_AVAILABLE
-            float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]);
-#else
-            float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size];
-#endif // FAST_FP16_AVAILABLE
-
-            if (gridDim.y == 1) {
-                dst_val.x /= kqsum_j;
-                dst_val.y /= kqsum_j;
+        constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
+            float2 tmp[cpy_ne_D];
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
            }
-            dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
+            ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
        }
+#else
+        constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
+            ggml_cuda_memcpy_1<cpy_ne_D*4>(
+                &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
+        }
+#endif // FAST_FP16_AVAILABLE
 
        if (gridDim.y != 1 && threadIdx.x == 0) {
-            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
        }
    }
 #else
@@ -602,15 +684,29 @@ template <int D, bool use_logit_softcap>
 static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
 
-    const int id        = ggml_cuda_get_device();
-    const int cc        = ggml_cuda_info().devices[id].cc;
-    const int warp_size = 32;
-    const int nwarps    = FATTN_TILE_NTHREADS / warp_size;
+    const int id        = ggml_cuda_get_device();
+    const int cc        = ggml_cuda_info().devices[id].cc;
+    const int warp_size = 32;
 
     constexpr size_t nbytes_shared = 0;
 
+#ifdef GGML_USE_HIP
+    if constexpr (D <= 128) {
+        if (Q->ne[1] > 32) {
+            constexpr int cols_per_block = 64;
+            const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
+            fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
+            const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
+            launch_fattn<D, cols_per_block, 1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
+            return;
+        }
+    }
+#endif // GGML_USE_HIP
+
     if (Q->ne[1] > 16) {
        constexpr int cols_per_block = 32;
+        const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
        fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
        const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
        launch_fattn<D, cols_per_block, 1>
@@ -619,6 +715,7 @@ static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml
     }
 
     constexpr int cols_per_block = 16;
+    const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
     fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
     const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
     launch_fattn<D, cols_per_block, 1>
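
For reference, a hypothetical host-side helper, not part of the patch, that condenses the cols_per_block dispatch implemented by launch_fattn_tile_switch_ncols after this change: on HIP, head sizes up to 128 with more than 32 query columns get the new 64-column kernel; otherwise the existing 32/16 split applies.

// Sketch only; sketch_cols_per_block is an invented name.
static int sketch_cols_per_block(const int D, const int ne1, const bool hip) {
    if (hip && D <= 128 && ne1 > 32) {
        return 64; // new HIP-only path added by this patch
    }
    return ne1 > 16 ? 32 : 16; // pre-existing dispatch
}
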