|
|
@@ -82,11 +82,12 @@ static __global__ void flash_attn_ext_f16(
|
|
|
const int sequence = blockIdx.z / ne02;
|
|
|
const int head = blockIdx.z - sequence*ne02;
|
|
|
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
|
|
- const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
|
|
- const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
|
|
- const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
|
|
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
|
|
- const half2 * mask2 = (const half2 *) maskh;
|
|
|
+ const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
|
|
+ const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
|
|
+ const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
|
|
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
|
|
+ const half2 * mask2 = (const half2 *) maskh;
|
|
|
+ const float * sinksf = (const float *) sinks;
|
|
|
|
|
|
const int stride_Q = nb01 / sizeof(float);
|
|
|
const int stride_KV = nb11 / sizeof(half);
|
|
|
@@ -381,6 +382,53 @@ static __global__ void flash_attn_ext_f16(
|
|
|
__syncthreads();
|
|
|
}
|
|
|
|
|
|
+ // Apply attention sinks
|
|
|
+ if (sinksf && blockIdx.y == 0) {
|
|
|
+ const float sinkf = sinksf[head];
|
|
|
+ const half sinkh = __float2half(sinkf);
|
|
|
+
|
|
|
+#pragma unroll
|
|
|
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
|
|
+ const int j = j0 + threadIdx.y;
|
|
|
+
|
|
|
+ if (std::is_same<KQ_acc_t, float>::value) {
|
|
|
+ float kqmax_new = fmaxf(KQ_max_f[j0/nwarps], sinkf);
|
|
|
+
|
|
|
+ const float KQ_max_scale = expf(KQ_max_f[j0/nwarps] - kqmax_new);
|
|
|
+ KQ_max_f[j0/nwarps] = kqmax_new;
|
|
|
+
|
|
|
+ KQ_rowsum_f[j0/nwarps] = KQ_rowsum_f[j0/nwarps] * KQ_max_scale + expf(sinkf - KQ_max_f[j0/nwarps]);
|
|
|
+
|
|
|
+ const half2 scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
|
|
|
+#pragma unroll
|
|
|
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
|
|
+ const int i = i0 + threadIdx.x;
|
|
|
+ if (i0 + warp_size > D/2 && i >= D/2) break;
|
|
|
+ VKQ2[j*(D_padded/2) + i] *= scale_h2;
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ half kqmax_old = __low2half(KQ_max_h2[j0/nwarps]);
|
|
|
+ half kqmax_new = fmaxf(kqmax_old, sinkh);
|
|
|
+ KQ_max_h2[j0/nwarps] = __half2half2(kqmax_new);
|
|
|
+
|
|
|
+ const half KQ_max_scale_h = hexp(kqmax_old - kqmax_new);
|
|
|
+ const half2 KQ_max_scale = __half2half2(KQ_max_scale_h);
|
|
|
+
|
|
|
+ KQ_rowsum_h2[j0/nwarps] = KQ_rowsum_h2[j0/nwarps] * KQ_max_scale;
|
|
|
+ const half val = hexp(sinkh - kqmax_new);
|
|
|
+ KQ_rowsum_h2[j0/nwarps].x = __hadd(KQ_rowsum_h2[j0/nwarps].x, val);
|
|
|
+
|
|
|
+#pragma unroll
|
|
|
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
|
|
|
+ const int i = i0 + threadIdx.x;
|
|
|
+ if (i0 + warp_size > D/2 && i >= D/2) break;
|
|
|
+ VKQ2[j*(D_padded/2) + i] *= KQ_max_scale;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ __syncthreads();
|
|
|
+ }
|
|
|
#pragma unroll
|
|
|
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
|
|
|
const int j_VKQ = j0 + threadIdx.y;
|