|
|
@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
|
|
|
typedef float (*vec_dot_KQ_f32_t)(
|
|
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
|
|
|
|
|
|
-template<typename T, int D>
|
|
|
+template<typename T, int D, int warp_size>
|
|
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
|
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
|
|
|
|
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
|
|
|
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
|
GGML_UNUSED(Q_v);
|
|
|
|
|
|
T sum = 0.0f;
|
|
|
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
|
|
return sum;
|
|
|
}
|
|
|
|
|
|
-template<typename T, int D>
|
|
|
+template<typename T, int D, int warp_size>
|
|
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
|
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
|
|
|
|
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
|
|
|
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
|
GGML_UNUSED(Q_v);
|
|
|
|
|
|
T sum = 0.0f;
|
|
|
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
|
|
return sum;
|
|
|
}
|
|
|
|
|
|
-template<typename T, int D>
|
|
|
+template<typename T, int D, int warp_size>
|
|
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
|
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
|
|
|
|
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
|
|
|
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
|
GGML_UNUSED(Q_v);
|
|
|
|
|
|
T sum = 0.0f;
|
|
|
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
|
|
return sum;
|
|
|
}
|
|
|
|
|
|
-template<typename T, int D>
|
|
|
+template<typename T, int D, int warp_size>
|
|
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
|
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
|
|
|
|
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
|
|
|
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
|
GGML_UNUSED(Q_v);
|
|
|
|
|
|
T sum = 0.0f;
|
|
|
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
|
|
return sum;
|
|
|
}
|
|
|
|
|
|
-template <typename T, int D>
|
|
|
+template <typename T, int D, int warp_size>
|
|
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
|
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
|
|
|
|
|
|
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
|
|
|
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
|
GGML_UNUSED(Q_v);
|
|
|
|
|
|
T sum = 0.0f;
|
|
|
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
|
|
|
return sum;
|
|
|
}
|
|
|
|
|
|
-template <typename T, int D>
|
|
|
+template <typename T, int D, int warp_size>
|
|
|
static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
|
|
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
|
|
|
|
|
|
const half2 * K_h2 = (const half2 *) K_c;
|
|
|
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
|
|
|
GGML_UNUSED(Q_q8);
|
|
|
GGML_UNUSED(Q_ds_v);
|
|
|
|
|
|
@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
|
|
|
return x[i];
|
|
|
}
|
|
|
|
|
|
-template <int D>
|
|
|
+template <int D, int warp_size = WARP_SIZE>
|
|
|
constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
|
|
|
- return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
|
|
|
- type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
|
|
|
- type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
|
|
|
- type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
|
|
|
- type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
|
|
|
- type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
|
|
|
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
|
|
|
nullptr;
|
|
|
}
|
|
|
|
|
|
-template <int D>
|
|
|
+template <int D, int warp_size = WARP_SIZE>
|
|
|
constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
|
|
|
- return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
|
|
|
- type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
|
|
|
- type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
|
|
|
- type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
|
|
|
- type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
|
|
|
- type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
|
|
|
+ return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
|
|
|
+ type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
|
|
|
nullptr;
|
|
|
}
|
|
|
|
|
|
@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
|
|
|
template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
|
|
|
void launch_fattn(
|
|
|
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
|
|
|
- const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
|
|
|
+ const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
|
|
|
+ const int warp_size = WARP_SIZE
|
|
|
) {
|
|
|
constexpr int ncols = ncols1 * ncols2;
|
|
|
|
|
|
@@ -704,8 +699,6 @@ void launch_fattn(
|
|
|
|
|
|
GGML_ASSERT(Q->ne[3] == 1);
|
|
|
|
|
|
- const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
|
|
|
-
|
|
|
ggml_cuda_pool & pool = ctx.pool();
|
|
|
cudaStream_t main_stream = ctx.stream();
|
|
|
const int id = ggml_cuda_get_device();
|
|
|
@@ -805,7 +798,6 @@ void launch_fattn(
|
|
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
|
|
|
|
GGML_ASSERT(block_dim.x % warp_size == 0);
|
|
|
- GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
|
|
|
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
|
|
|
(const char *) Q->data,
|
|
|
K_data,
|