@@ -564,6 +564,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x;
+#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+ // Without the v_dot2_f32_f16 instruction, there is a higher risk of numerical overflow in the KQ calculation.
+ // Therefore, scale down the Q values and apply the inverse scale to the FP32 KQ values afterwards.
+ KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] *= 4.0f;
+#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+
if (use_logit_softcap) {
KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
}
@@ -858,6 +864,11 @@ static __global__ void flash_attn_tile(
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
+#if defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
+ // Without the v_dot2_f32_f16 instruction, there is a higher risk of numerical overflow in the KQ calculation.
+ // Therefore, scale down the Q values and apply the inverse scale to the FP32 KQ values afterwards.
+ tmp_h2[i1/2] *= make_half2(0.25f, 0.25f);
+#endif // defined(FAST_FP16_AVAILABLE) && !defined(V_DOT2_F32_F16_AVAILABLE)
}
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
&Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)],
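
As a standalone illustration of the trick in this patch (not part of the diff itself): a minimal CUDA sketch of a KQ dot product accumulated in FP16, the path taken when v_dot2_f32_f16 is unavailable, with Q scaled down by 0.25f on conversion to half precision and the FP32 result multiplied by 4.0f afterwards. The kernel name toy_kq_dot, the 32-thread block, and the float layout of Q and K are assumptions for the sketch; it requires a GPU with native FP16 arithmetic (sm_53 or newer).

#include <cuda_fp16.h>

// Toy kernel: one 32-thread block computes the dot product of one row of Q with one row of K.
// Products are accumulated in FP16, so Q is scaled down by 0.25f to reduce the overflow risk
// and the inverse scale of 4.0f is applied to the FP32 result at the end.
__global__ void toy_kq_dot(const float * Q, const half2 * K, float * KQ, const int D) {
    half2 sum2 = make_half2(0.0f, 0.0f);
    for (int i = threadIdx.x; i < D/2; i += blockDim.x) {
        const half2 q = __floats2half2_rn(0.25f*Q[2*i + 0], 0.25f*Q[2*i + 1]); // scale down Q
        sum2 = __hadd2(sum2, __hmul2(q, K[i]));                                // FP16 accumulation
    }
    float sum = __low2float(sum2) + __high2float(sum2);
    for (int offset = 16; offset > 0; offset >>= 1) {                          // warp reduction
        sum += __shfl_down_sync(0xFFFFFFFF, sum, offset);
    }
    if (threadIdx.x == 0) {
        KQ[blockIdx.x] = 4.0f*sum;                                             // undo the 0.25f scale
    }
}

The factor 0.25f keeps the partial FP16 sums further away from the half-precision maximum of 65504, and restoring it in FP32 is exact because 4.0f is a power of two.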