@@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
     constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
 #else
     constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
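
// Note on the gate being renamed throughout this patch (a sketch, not part of
// the diff): FAST_FP16_AVAILABLE asks "is generic f16 arithmetic fast here?",
// while V_DOT2_F32_F16_AVAILABLE asks the narrower question this kernel cares
// about: "is a fused 2-wide f16 dot product available?" (the name matches
// AMD's v_dot2_f32_f16 instruction, which multiplies two f16 pairs and
// accumulates into f32). The exact definition lives in ggml's CUDA/HIP common
// header; the alias below is hypothetical and only illustrates the
// compile-time type dispatch that every following hunk re-gates.
#ifdef V_DOT2_F32_F16_AVAILABLE
typedef half2  fattn_vec_acc_t; // keep working data in f16x2 on the fast path
#else
typedef float2 fattn_vec_acc_t; // fall back to f32x2 otherwise
#endif // V_DOT2_F32_F16_AVAILABLE
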
@@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
 
     constexpr int ne_KQ = ncols*D;
     constexpr int ne_combine = nwarps*V_cols_per_iter*D;
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
     __shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
 #else
     float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
     __shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
     float KQ_max[ncols];
     float KQ_sum[ncols];
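
// Sizing note for the __shared__ KQ buffer above, as a worked example with
// assumed template parameters (e.g. D = 128, ncols = 2, nwarps = 4,
// V_cols_per_iter = 2 — illustrative values, not fixed by the patch):
//   ne_KQ      = ncols*D                  = 2*128   = 256  softmax scores
//   ne_combine = nwarps*V_cols_per_iter*D = 4*2*128 = 1024 staged VKQ values
// The same buffer first holds the KQ scores and is later reused to stage
// partial VKQ results for the cross-warp combine, so it is sized to the
// larger of the two.
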
@@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
     }
 
     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
     half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
 #else
     float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
     int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
     if constexpr (Q_q8_1) {
@@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
 
         __syncthreads();
     } else {
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
         const half2 scale_h2 = make_half2(scale, scale);
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
                 Q_reg[j][k].y *= scale;
             }
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
     }
 
     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
@@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
             KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
             KQ[j*nthreads + tid] = KQ_reg[j];
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
             for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
                 VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                 VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
 
 #ifndef GGML_USE_HIP
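
// The KQ_max_scale updates in the two hunks above are the streaming
// ("online") softmax recurrence: whenever a new running maximum appears, the
// previous sum and accumulators are rescaled before the new column is added.
// A scalar sketch with illustrative names (not kernel identifiers) for one
// new score s with value v:
static __device__ void online_softmax_step(float s, float v, float & m, float & sum, float & acc) {
    const float m_new = fmaxf(m, s);
    const float scale = expf(m - m_new); // KQ_max_scale in the kernel
    const float p     = expf(s - m_new); // the stored KQ value for this score
    sum = sum*scale + p;                 // the KQ_sum[j] update
    acc = acc*scale + p*v;               // VKQ: rescale the old sum, add the new column
    m   = m_new;
}
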
@@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
         for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
             const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             half2 KQ_k[ncols];
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
@@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
                 }
             }
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
@@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
 
             KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
 #pragma unroll
             for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
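
// The sink handled above acts as one extra logit per head that enters the
// softmax denominator without a matching V column (only lane 0 adds it, so it
// is counted once by the warp reduction). Scalar sketch of the resulting
// normalization, as an illustrative helper assuming a per-head `sink` logit:
static __device__ float softmax_sum_with_sink(const float * logits, int n, float sink) {
    float m = sink;
    for (int i = 0; i < n; ++i) {
        m = fmaxf(m, logits[i]);
    }
    float sum = expf(sink - m); // the sink inflates the denominator only
    for (int i = 0; i < n; ++i) {
        sum += expf(logits[i] - m);
    }
    return sum;
}
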
@@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
                 VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
                 VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
             }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
         }
     }
 
@@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
         const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
         KQ_max[j_VKQ] = kqmax_new;
 
-#ifdef FAST_FP16_AVAILABLE
+#ifdef V_DOT2_F32_F16_AVAILABLE
         half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
             + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
 
@@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
             ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
             ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
         }
-#endif // FAST_FP16_AVAILABLE
+#endif // V_DOT2_F32_F16_AVAILABLE
 
         KQ_sum[j_VKQ] *= kqmax_scale;
         KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);
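
// End-of-kernel picture for the staging done in the last two hunks (a sketch;
// the exact combine step is outside the shown diff context): each warp
// rescales its VKQ registers by kqmax_scale, copies them into the reused
// shared buffer via ggml_cuda_memcpy_1, and the final output applies the
// softmax normalization once, after reducing across warps:
//   dst[j][i] = (sum over warps w of VKQ_w[j][i]) / (sum over warps w of KQ_sum_w[j])
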