- #include "common.cuh"
- #include "fattn-common.cuh"
- static int ggml_cuda_fattn_vec_get_nthreads_host(const int cc) {
- return 128;
- GGML_UNUSED(cc);
- }
- static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() {
- return 128;
- }
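- // The host and device variants must agree: the device function feeds
- // __launch_bounds__ below, while the host function sizes the actual launch.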
- // Currently, LLVM with the amdgcn target does not support unrolling loops
- // that contain a break that cannot be resolved at compile time.
- #ifdef __clang__
- #pragma clang diagnostic push
- #pragma clang diagnostic ignored "-Wpass-failed"
- #endif // __clang__
- template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
- __launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1)
- static __global__ void flash_attn_ext_vec(
- const char * __restrict__ Q,
- const char * __restrict__ K,
- const char * __restrict__ V,
- const char * __restrict__ mask,
- const char * __restrict__ sinks,
- const int * __restrict__ KV_max,
- float * __restrict__ dst,
- float2 * __restrict__ dst_meta,
- const float scale,
- const float max_bias,
- const float m0,
- const float m1,
- const uint32_t n_head_log2,
- const float logit_softcap,
- const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
- const int32_t nb01, const int32_t nb02, const int32_t nb03,
- const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
- const int32_t nb11, const int32_t nb12, const int64_t nb13,
- const int32_t nb21, const int32_t nb22, const int64_t nb23,
- const int32_t ne31, const int32_t ne32, const int32_t ne33,
- const int32_t nb31, const int32_t nb32, const int64_t nb33) {
- #ifdef FLASH_ATTN_AVAILABLE
- // Skip unused kernel variants for faster compilation:
- if (use_logit_softcap && !(D == 128 || D == 256)) {
- GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
- max_bias, m0, m1, n_head_log2, logit_softcap,
- ne00, ne01, ne02, ne03,
- nb01, nb02, nb03,
- ne10, ne11, ne12, ne13,
- nb11, nb12, nb13,
- nb21, nb22, nb23,
- ne31, ne32, ne33,
- nb31, nb32, nb33);
- NO_DEVICE_CODE;
- return;
- }
- // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
- constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
- constexpr int cpy_ne = cpy_nb / 4;
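- // cpy_nb is the widest copy (in bytes) that ggml_cuda_memcpy_1 issues here;
- // cpy_ne is how many 4-byte elements (float/half2) fit into one such copy.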
- #ifdef GGML_USE_HIP
- #ifdef RDNA
- constexpr int nthreads_KQ_q = 2;
- #else
- constexpr int nthreads_KQ_q = 4;
- #endif // RDNA
- constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32);
- #else
- constexpr int nthreads_KQ_q = (D/4 < 32 ? D/4 : 32);
- constexpr int nthreads_V_q = (D/4 < 32 ? D/4 : 32);
- #endif // GGML_USE_HIP
- constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device();
- constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
- constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
- static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_KQ");
- static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V");
- constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
- constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V;
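- // nthreads_KQ lanes cooperate on each KQ dot product and nthreads_V lanes on
- // each V row, so a warp covers V_cols_per_iter V rows per iteration.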
- constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
- constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
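- // With a quantized K, Q is converted to q8_1 so the KQ dot product can run on
- // integer data rather than on dequantized floats.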
- #ifdef FAST_FP16_AVAILABLE
- constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
- #else
- constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
- #endif // FAST_FP16_AVAILABLE
- const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
- const int sequence = blockIdx.z / ne02;
- const int head = blockIdx.z - sequence*ne02;
- const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
- Q += nb03*sequence + nb02* head + nb01*ic0;
- K += nb13*sequence + nb12*(head / gqa_ratio);
- V += nb23*sequence + nb22*(head / gqa_ratio);
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
- const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
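- // ALiBi: heads below n_head_log2 use slope m0^(head+1), later heads use
- // m1^(2*(head - n_head_log2) + 1); with max_bias <= 0 the slope is 1.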
- static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
- constexpr int nwarps = nthreads / WARP_SIZE;
- const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
- __builtin_assume(tid < nthreads);
- constexpr int ne_KQ = ncols*D;
- constexpr int ne_combine = nwarps*V_cols_per_iter*D;
- #ifdef FAST_FP16_AVAILABLE
- half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
- __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
- #else
- float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
- __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
- #endif // FAST_FP16_AVAILABLE
- float KQ_max[ncols];
- float KQ_sum[ncols];
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- KQ_max[j] = -FLT_MAX/2.0f;
- KQ_sum[j] = 0.0f;
- }
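- // Running online-softmax state, updated in the KV loop below: for each new
- // score s,
- //   m'   = max(m, s)
- //   sum' = sum * exp(m - m') + exp(s - m')
- //   acc' = acc * exp(m - m') + exp(s - m') * v
- // so earlier contributions are rescaled whenever the running maximum grows.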
- // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
- #ifdef FAST_FP16_AVAILABLE
- half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
- #else
- float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
- #endif // FAST_FP16_AVAILABLE
- int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
- float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
- if constexpr (Q_q8_1) {
- #pragma unroll
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
- if (j0 + nwarps > ncols && j >= ncols) {
- break;
- }
- // Reuse KQ as temporary storage for converting Q to q8_1:
- int * tmp_q_i32 = (int *) &KQ[j*D];
- float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
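- // Shared layout per Q column: D/4 ints packing 4 int8 quants each, followed
- // by one float2 (scale d, value sum) per block of QK8_1 quantized values.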
- // Set memory to zero if out of bounds:
- if (ncols > 1 && ic0 + j >= ne01) {
- #pragma unroll
- for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
- const int i = i0 + threadIdx.x;
- if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
- tmp_q_i32[i] = 0;
- }
- }
- if (threadIdx.x < D/QK8_1) {
- tmp_q_ds[threadIdx.x] = make_float2(0.0f, 0.0f);
- }
- } else {
- const float * Q_f = (const float *) (Q + j*nb01);
- constexpr int nthreads_quantize = D/sizeof(int) < WARP_SIZE ? D/sizeof(int) : WARP_SIZE;
- #pragma unroll
- for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
- quantize_q8_1_to_shared<float2, nthreads_quantize>
- (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
- }
- }
- }
- __syncthreads();
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- int * tmp_q_i32 = (int *) &KQ[j*D];
- float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
- #pragma unroll
- for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
- const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ);
- Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
- Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1];
- }
- }
- __syncthreads();
- } else {
- #ifdef FAST_FP16_AVAILABLE
- const half2 scale_h2 = make_half2(scale, scale);
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- const float2 * Q_j = (const float2 *) (Q + j*nb01);
- #pragma unroll
- for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
- const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
- float2 tmp[cpy_ne] = {{0.0f, 0.0f}};
- if (ncols == 1 || ic0 + j < ne01) {
- ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
- ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
- }
- #pragma unroll
- for (int i1 = 0; i1 < cpy_ne; ++i1) {
- Q_reg[j][i0/nthreads_KQ + i1] = make_half2(tmp[i1].x, tmp[i1].y);
- }
- }
- #pragma unroll
- for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
- Q_reg[j][k] *= scale_h2;
- }
- }
- #else
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- const float2 * Q_j = (const float2 *) (Q + j*nb01);
- #pragma unroll
- for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
- const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne;
- if (ncols == 1 || ic0 + j < ne01) {
- ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
- ggml_cuda_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
- }
- }
- #pragma unroll
- for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
- Q_reg[j][k].x *= scale;
- Q_reg[j][k].y *= scale;
- }
- }
- #endif // FAST_FP16_AVAILABLE
- }
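- // KV_max, if present, gives an upper bound on the KV positions this Q block
- // can attend to (the tail beyond it is fully masked), so the loop ends early.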
- const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
- K += blockIdx.y*nthreads * nb11;
- V += blockIdx.y*nthreads * nb21;
- maskh += blockIdx.y*nthreads;
- for (int k_VKQ_0 = blockIdx.y*nthreads; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nthreads,
- // Increment pointers after each loop:
- K += gridDim.y*nthreads*nb11, V += gridDim.y*nthreads*nb21, maskh += gridDim.y*nthreads) {
- // Calculate KQ tile and keep track of new maximum KQ values:
- float KQ_reg[ncols]; // KQ in registers.
- float KQ_max_new[ncols];
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- KQ_max_new[j] = KQ_max[j];
- }
- #pragma unroll
- for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
- const int i_KQ = threadIdx.y*WARP_SIZE + (nthreads_KQ == WARP_SIZE ? 0 : (threadIdx.x & ~(nthreads_KQ-1))) + i_KQ_0;
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
- sum = warp_reduce_sum<nthreads_KQ>(sum);
- if (use_logit_softcap) {
- sum = logit_softcap*tanhf(sum);
- }
- if (mask) {
- sum += slope*__half2float(maskh[j*ne11 + i_KQ]);
- }
- KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
- if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
- KQ_reg[j] = sum;
- }
- }
- }
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- #pragma unroll
- for (int offset = nthreads_KQ; offset < WARP_SIZE; offset <<= 1) {
- KQ_max_new[j] = fmaxf(KQ_max_new[j], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[j], offset, WARP_SIZE));
- }
- const float KQ_max_scale = expf(KQ_max[j] - KQ_max_new[j]);
- KQ_max[j] = KQ_max_new[j];
- KQ_reg[j] = expf(KQ_reg[j] - KQ_max[j]);
- KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
- KQ[j*nthreads + tid] = KQ_reg[j];
- #ifdef FAST_FP16_AVAILABLE
- const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
- VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
- }
- #else
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
- VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
- VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
- }
- #endif // FAST_FP16_AVAILABLE
- }
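- // KQ was written lane-by-lane above and is read across lanes below; CUDA
- // needs an explicit __syncwarp() for that, while HIP omits it, presumably
- // because the lanes of a wavefront execute in lockstep.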
- #ifndef GGML_USE_HIP
- __syncwarp();
- #endif // GGML_USE_HIP
- #pragma unroll
- for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
- const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
- #ifdef FAST_FP16_AVAILABLE
- half2 KQ_k[ncols];
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- KQ_k[j] = __half2half2(KQ[j*nthreads + k]);
- }
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
- half2 tmp[V_rows_per_thread/2];
- dequantize_V(V + k*nb21, tmp,
- 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
- #pragma unroll
- for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
- }
- }
- }
- #else
- float KQ_k[ncols];
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- KQ_k[j] = KQ[j*nthreads + k];
- }
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
- float2 tmp[V_rows_per_thread/2];
- dequantize_V(V + k*nb21, tmp,
- 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread);
- #pragma unroll
- for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x += tmp[i_VKQ_1].x*KQ_k[j];
- VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y += tmp[i_VKQ_1].y*KQ_k[j];
- }
- }
- }
- #endif // FAST_FP16_AVAILABLE
- }
- }
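- // Attention sinks: an extra per-head logit that joins the softmax
- // normalization without a matching V row, so only the running max and sum
- // change and the accumulator is merely rescaled.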
- if (sinks && blockIdx.y == 0) {
- const float sink = ((const float *) sinks)[head];
- #pragma unroll
- for (int j0 = 0; j0 < ncols; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
- if (j0 + nwarps > ncols && j >= ncols) {
- break;
- }
- const float kqmax_new_j = fmaxf(sink, KQ_max[j]);
- const float KQ_max_scale = expf(KQ_max[j] - kqmax_new_j);
- KQ_max[j] = kqmax_new_j;
- KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
- #ifdef FAST_FP16_AVAILABLE
- const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
- VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
- }
- #else
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
- VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
- VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
- }
- #endif // FAST_FP16_AVAILABLE
- }
- }
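- // Combine the per-warp results: each warp publishes its running max and sum
- // to shared memory so the whole block can reduce over them.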
- __shared__ float KQ_max_shared[ncols][WARP_SIZE];
- __shared__ float KQ_sum_shared[ncols][WARP_SIZE];
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- if (threadIdx.y == 0) {
- KQ_max_shared[j][threadIdx.x] = -FLT_MAX/2.0f;
- KQ_sum_shared[j][threadIdx.x] = 0.0f;
- }
- }
- __syncthreads();
- #pragma unroll
- for (int j = 0; j < ncols; ++j) {
- if (threadIdx.x == 0) {
- KQ_max_shared[j][threadIdx.y] = KQ_max[j];
- }
- }
- __syncthreads();
- #pragma unroll
- for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
- if (ncols > 1 && ic0 + j_VKQ >= ne01) {
- break;
- }
- float kqmax_new = KQ_max_shared[j_VKQ][threadIdx.x];
- kqmax_new = warp_reduce_max(kqmax_new);
- const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
- KQ_max[j_VKQ] = kqmax_new;
- #ifdef FAST_FP16_AVAILABLE
- half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
- + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
- const half2 kqmax_scale_h2 = make_half2(kqmax_scale, kqmax_scale);
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
- VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
- }
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
- const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
- ggml_cuda_memcpy_1<V_rows_per_thread*sizeof(half)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
- }
- #else
- float2 * VKQ_tmp = (float2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
- + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
- VKQ[j_VKQ][i_VKQ_0/nthreads_V].x *= kqmax_scale;
- VKQ[j_VKQ][i_VKQ_0/nthreads_V].y *= kqmax_scale;
- }
- #pragma unroll
- for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
- const int i_VKQ = i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*(V_rows_per_thread/2);
- ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
- ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
- }
- #endif // FAST_FP16_AVAILABLE
- KQ_sum[j_VKQ] *= kqmax_scale;
- KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
- if (threadIdx.x == 0) {
- KQ_sum_shared[j_VKQ][threadIdx.y] = KQ_sum[j_VKQ];
- }
- __syncthreads();
- if (nthreads <= D || tid < D) { // threads past the head size have nothing to write
- KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ][threadIdx.x];
- KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
- #pragma unroll
- for (int i0 = 0; i0 < D; i0 += nthreads) {
- float dst_val = 0;
- #pragma unroll
- for (int w = 0; w < nwarps; ++w) {
- #pragma unroll
- for (int v = 0; v < V_cols_per_iter; ++v) {
- dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
- }
- }
- if (gridDim.y == 1) {
- dst_val /= KQ_sum[j_VKQ];
- }
- dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + i0 + tid] = dst_val;
- }
- }
- if (j_VKQ < ncols-1) {
- __syncthreads();
- }
- }
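- // When the KV dimension is split across gridDim.y blocks, each block stores
- // its (max, sum) pair so a separate combine pass can merge partial results;
- // with gridDim.y == 1 the output was already normalized above.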
- if (gridDim.y != 1 && tid < ncols && (ncols == 1 || ic0 + tid < ne01)) {
- dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]);
- }
- #else
- GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
- max_bias, m0, m1, n_head_log2, logit_softcap,
- ne00, ne01, ne02, ne03,
- nb01, nb02, nb03,
- ne10, ne11, ne12, ne13,
- nb11, nb12, nb13,
- nb21, nb22, nb23,
- ne31, ne32, ne33,
- nb31, nb32, nb33);
- NO_DEVICE_CODE;
- #endif // FLASH_ATTN_AVAILABLE
- }
- #ifdef __clang__
- #pragma clang diagnostic pop
- #endif // __clang__
- template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
- void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
- const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
- const int nwarps = nthreads / WARP_SIZE;
- fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
- constexpr bool need_f16_K = false;
- constexpr bool need_f16_V = false;
- constexpr size_t nbytes_shared = 0;
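- // The vec kernel reads quantized K/V directly, so no F16 conversion of K/V
- // and no extra shared memory are requested from launch_fattn.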
- launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
- }
- template <int D, ggml_type type_K, ggml_type type_V>
- void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * KQV = dst;
- const ggml_tensor * Q = dst->src[0];
- const ggml_tensor * K = dst->src[1];
- const ggml_tensor * V = dst->src[2];
- GGML_ASSERT(K->type == type_K);
- GGML_ASSERT(V->type == type_V);
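- // op_params of GGML_OP_FLASH_ATTN_EXT holds {scale, max_bias, logit_softcap}.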
- float logit_softcap;
- memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
- if (Q->ne[1] == 1) {
- constexpr int cols_per_block = 1;
- if (logit_softcap == 0.0f) {
- constexpr bool use_logit_softcap = false;
- ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
- } else {
- constexpr bool use_logit_softcap = true;
- ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
- }
- return;
- }
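- // Batched Q (ne01 > 1): process two Q columns per block so that each K/V
- // load is shared between them.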
- constexpr int cols_per_block = 2;
- if (logit_softcap == 0.0f) {
- constexpr bool use_logit_softcap = false;
- ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
- } else {
- constexpr bool use_logit_softcap = true;
- ggml_cuda_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
- }
- }
- #define DECL_FATTN_VEC_CASE(D, type_K, type_V) \
- template void ggml_cuda_flash_attn_ext_vec_case \
- <D, type_K, type_V>(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
- #define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
- extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \
- extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
- extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
- extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
- extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
- extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0);
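- // These extern declarations match explicit instantiations that are compiled
- // in separate per-type translation units to keep build times manageable.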
- EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
- EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
- EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
- EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
- EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
- EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
- EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
- EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
- EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
- EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
- EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
- EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
- EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
- EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
- EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
- EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
- EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
- EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
|