@@ -152,16 +152,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
     GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
 
     int64_t total_vram = 0;
-#if defined(GGML_CUDA_FORCE_MMQ)
-    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
+#ifdef GGML_CUDA_FORCE_MMQ
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
 #else
-    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
-#endif
-#if defined(CUDA_USE_TENSOR_CORES)
-    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
+#endif // GGML_CUDA_FORCE_MMQ
+#ifdef GGML_CUDA_FORCE_CUBLAS
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
 #else
-    GGML_CUDA_LOG_INFO("%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
-#endif
+    GGML_CUDA_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
+#endif // GGML_CUDA_FORCE_CUBLAS
     GGML_CUDA_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
     for (int id = 0; id < info.device_count; ++id) {
         int device_vmm = 0;
@@ -1873,9 +1873,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer);
 
-    int64_t min_compute_capability = INT_MAX;
+    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
+    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    bool use_mul_mat_q = ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    bool any_gpus_with_slow_fp16 = false;
 
-    bool any_pascal_with_slow_fp16 = false;
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
         auto & tensor_split = buft_ctx->tensor_split;
@@ -1885,55 +1893,18 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            if (min_compute_capability > ggml_cuda_info().devices[id].cc) {
-                min_compute_capability = ggml_cuda_info().devices[id].cc;
-            }
-            if (ggml_cuda_info().devices[id].cc == 610) {
-                any_pascal_with_slow_fp16 = true;
-            }
+            const int cc = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A;
+            use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
         }
     } else {
-        min_compute_capability = ggml_cuda_info().devices[ctx.device].cc;
-        any_pascal_with_slow_fp16 = ggml_cuda_info().devices[ctx.device].cc == 610;
+        const int cc = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_vec_q = use_mul_mat_vec_q && cc >= MIN_CC_DP4A;
+        use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc);
     }
 
-    // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src1->ne[1] == 1;
-
-    bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
-
-    bool use_mul_mat_q = ggml_cuda_supports_mmq(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
-
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-
-    const bool fp16_performance_good = min_compute_capability >= CC_RDNA1;
-
-#ifdef CUDA_USE_TENSOR_CORES
-    use_mul_mat_q = use_mul_mat_q && min_compute_capability < CC_RDNA3;
-#endif // CUDA_USE_TENSOR_CORES
-
-#else
-
-    // fp16 performance is good on Volta or newer and on P100 (compute capability 6.0)
-    const bool fp16_performance_good = min_compute_capability >= CC_PASCAL && !any_pascal_with_slow_fp16;
-
-    // mmvq and mmq need the __dp4a instruction which on NVIDIA is only available for CC >= 6.1
-    use_mul_mat_vec_q = use_mul_mat_vec_q && min_compute_capability >= MIN_CC_DP4A;
-    use_mul_mat_q = use_mul_mat_q && min_compute_capability >= MIN_CC_DP4A;
-
-#ifdef CUDA_USE_TENSOR_CORES
-    // when tensor cores are available, use them for large batch size
-    // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-    use_mul_mat_q = use_mul_mat_q && (!fp16_performance_good || src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
-#endif // CUDA_USE_TENSOR_CORES
-
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-
     // if mmvq is available it's a better choice than dmmv:
 #ifndef GGML_CUDA_FORCE_DMMV
     use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
@@ -1947,21 +1918,22 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // KQ single-batch
+    if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQ single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // KQV single-batch
+    } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // FP32 precision KQV single-batch for batch size 1 without FlashAttention
         ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || fp16_performance_good) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch
-        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
     } else if (use_mul_mat_vec_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
     } else if (use_mul_mat_q) {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16)
+            && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // KQ + KQV multi-batch without FlashAttention
+        ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else {
         ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
     }