|
|
@@ -1,9 +1,9 @@
|
|
|
#include "ggml.h"
|
|
|
#include "common.cuh"
|
|
|
-#include "mmv.cuh"
|
|
|
+#include "mmvf.cuh"
|
|
|
|
|
|
template <typename T, typename type_acc, int ncols_dst, int block_size>
|
|
|
-static __global__ void mul_mat_vec(
|
|
|
+static __global__ void mul_mat_vec_f(
|
|
|
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
|
|
|
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
|
|
|
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
|
|
|
@@ -37,7 +37,7 @@ static __global__ void mul_mat_vec(
|
|
|
|
|
|
float sumf[ncols_dst] = {0.0f};
|
|
|
|
|
|
- if constexpr (std::is_same<T, float>::value) {
|
|
|
+ if constexpr (std::is_same_v<T, float>) {
|
|
|
const float2 * x2 = (const float2 *) x;
|
|
|
|
|
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
|
@@ -50,10 +50,10 @@ static __global__ void mul_mat_vec(
|
|
|
sumf[j] += tmpx.y*tmpy.y;
|
|
|
}
|
|
|
}
|
|
|
- } else if constexpr (std::is_same<T, half>::value) {
|
|
|
+ } else if constexpr (std::is_same_v<T, half>) {
|
|
|
const half2 * x2 = (const half2 *) x;
|
|
|
|
|
|
- if (std::is_same<type_acc, float>::value) {
|
|
|
+ if (std::is_same_v<type_acc, float>) {
|
|
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
|
const float2 tmpx = __half22float2(x2[col2]);
|
|
|
|
|
|
@@ -86,7 +86,7 @@ static __global__ void mul_mat_vec(
|
|
|
NO_DEVICE_CODE;
|
|
|
#endif // FP16_AVAILABLE
|
|
|
}
|
|
|
- } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
|
|
|
+ } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
|
|
|
const int * x2 = (const int *) x;
|
|
|
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
|
|
|
const int tmpx = x2[col2];
|
|
|
@@ -98,7 +98,7 @@ static __global__ void mul_mat_vec(
|
|
|
}
|
|
|
}
|
|
|
} else {
|
|
|
- static_assert(std::is_same<T, void>::value, "unsupported type");
|
|
|
+ static_assert(std::is_same_v<T, void>, "unsupported type");
|
|
|
}
|
|
|
|
|
|
#pragma unroll
|
|
|
@@ -126,7 +126,7 @@ static __global__ void mul_mat_vec(
|
|
|
}
|
|
|
|
|
|
template <typename T, typename type_acc, int ncols_dst>
|
|
|
-static void launch_mul_mat_vec_cuda(
|
|
|
+static void launch_mul_mat_vec_f_cuda(
|
|
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
|
const int64_t ncols, const int64_t nrows,
|
|
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
|
@@ -141,11 +141,9 @@ static void launch_mul_mat_vec_cuda(
|
|
|
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
|
|
|
const int64_t channel_ratio = nchannels_dst / nchannels_x;
|
|
|
const int64_t sample_ratio = nsamples_dst / nsamples_x;
|
|
|
- int device;
|
|
|
- int warp_size;
|
|
|
|
|
|
- CUDA_CHECK(cudaGetDevice(&device));
|
|
|
- warp_size = ggml_cuda_info().devices[device].warp_size;
|
|
|
+ const int device = ggml_cuda_get_device();
|
|
|
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
|
|
|
|
|
|
int64_t block_size_best = warp_size;
|
|
|
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
|
|
|
@@ -161,54 +159,54 @@ static void launch_mul_mat_vec_cuda(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- const int smem = warp_size*sizeof(float);
|
|
|
+ const int nbytes_shared = warp_size*sizeof(float);
|
|
|
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
|
|
|
const dim3 block_dims(block_size_best, 1, 1);
|
|
|
switch (block_size_best) {
|
|
|
case 32: {
|
|
|
- mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
} break;
|
|
|
case 64: {
|
|
|
- mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
} break;
|
|
|
case 96: {
|
|
|
- mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
} break;
|
|
|
case 128: {
|
|
|
- mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
} break;
|
|
|
case 160: {
|
|
|
- mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
} break;
|
|
|
case 192: {
|
|
|
- mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
} break;
|
|
|
case 224: {
|
|
|
- mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
} break;
|
|
|
case 256: {
|
|
|
- mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
|
|
|
+ mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
|
|
|
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
|
|
|
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
|
|
|
@@ -220,7 +218,7 @@ static void launch_mul_mat_vec_cuda(
|
|
|
}
|
|
|
|
|
|
template <typename T, typename type_acc>
|
|
|
-static void mul_mat_vec_cuda_switch_ncols_dst(
|
|
|
+static void mul_mat_vec_f_cuda_switch_ncols_dst(
|
|
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
|
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
|
|
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
|
|
|
@@ -230,49 +228,49 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
|
|
|
cudaStream_t stream) {
|
|
|
switch (ncols_dst) {
|
|
|
case 1:
|
|
|
- launch_mul_mat_vec_cuda<T, type_acc, 1>
|
|
|
+ launch_mul_mat_vec_f_cuda<T, type_acc, 1>
|
|
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
break;
|
|
|
case 2:
|
|
|
- launch_mul_mat_vec_cuda<T, type_acc, 2>
|
|
|
+ launch_mul_mat_vec_f_cuda<T, type_acc, 2>
|
|
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
break;
|
|
|
case 3:
|
|
|
- launch_mul_mat_vec_cuda<T, type_acc, 3>
|
|
|
+ launch_mul_mat_vec_f_cuda<T, type_acc, 3>
|
|
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
break;
|
|
|
case 4:
|
|
|
- launch_mul_mat_vec_cuda<T, type_acc, 4>
|
|
|
+ launch_mul_mat_vec_f_cuda<T, type_acc, 4>
|
|
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
break;
|
|
|
case 5:
|
|
|
- launch_mul_mat_vec_cuda<T, type_acc, 5>
|
|
|
+ launch_mul_mat_vec_f_cuda<T, type_acc, 5>
|
|
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
break;
|
|
|
case 6:
|
|
|
- launch_mul_mat_vec_cuda<T, type_acc, 6>
|
|
|
+ launch_mul_mat_vec_f_cuda<T, type_acc, 6>
|
|
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
break;
|
|
|
case 7:
|
|
|
- launch_mul_mat_vec_cuda<T, type_acc, 7>
|
|
|
+ launch_mul_mat_vec_f_cuda<T, type_acc, 7>
|
|
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
break;
|
|
|
case 8:
|
|
|
- launch_mul_mat_vec_cuda<T, type_acc, 8>
|
|
|
+ launch_mul_mat_vec_f_cuda<T, type_acc, 8>
|
|
|
(x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
@@ -284,7 +282,7 @@ static void mul_mat_vec_cuda_switch_ncols_dst(
|
|
|
}
|
|
|
|
|
|
template<typename T>
|
|
|
-static void mul_mat_vec_cuda(
|
|
|
+static void mul_mat_vec_f_cuda(
|
|
|
const T * x, const float * y, const int32_t * ids, float * dst,
|
|
|
const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
|
|
|
const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
|
|
|
@@ -292,22 +290,22 @@ static void mul_mat_vec_cuda(
|
|
|
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
|
|
|
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
|
|
|
enum ggml_prec prec, cudaStream_t stream) {
|
|
|
- if constexpr(std::is_same<T, half>::value) {
|
|
|
+ if constexpr(std::is_same_v<T, half>) {
|
|
|
if (prec == GGML_PREC_DEFAULT) {
|
|
|
- mul_mat_vec_cuda_switch_ncols_dst<T, half>
|
|
|
+ mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
|
|
|
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
return;
|
|
|
}
|
|
|
}
|
|
|
- mul_mat_vec_cuda_switch_ncols_dst<T, float>
|
|
|
+ mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
|
|
|
(x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
|
|
|
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
|
|
|
}
|
|
|
|
|
|
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
|
|
+void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
|
|
|
GGML_ASSERT( src1->type == GGML_TYPE_F32);
|
|
|
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
|
|
|
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
|
|
@@ -355,19 +353,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
|
|
switch (src0->type) {
|
|
|
case GGML_TYPE_F32: {
|
|
|
const float * src0_d = (const float *) src0->data;
|
|
|
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
|
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
|
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
|
|
} break;
|
|
|
case GGML_TYPE_F16: {
|
|
|
const half * src0_d = (const half *) src0->data;
|
|
|
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
|
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
|
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
|
|
} break;
|
|
|
case GGML_TYPE_BF16: {
|
|
|
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
|
|
|
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
|
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
|
|
|
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
|
|
|
ne03, ne3, s03, s13, s3, prec, ctx.stream());
|
|
|
} break;
|
|
|
@@ -376,7 +374,7 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-void ggml_cuda_op_mul_mat_vec(
|
|
|
+void ggml_cuda_op_mul_mat_vec_f(
|
|
|
ggml_backend_cuda_context & ctx,
|
|
|
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
|
|
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
|
|
@@ -414,19 +412,19 @@ void ggml_cuda_op_mul_mat_vec(
|
|
|
switch (src0->type) {
|
|
|
case GGML_TYPE_F32: {
|
|
|
const float * src0_d = (const float *) src0_dd_i;
|
|
|
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
|
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
|
|
} break;
|
|
|
case GGML_TYPE_F16: {
|
|
|
const half * src0_d = (const half *) src0_dd_i;
|
|
|
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
|
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
|
|
} break;
|
|
|
case GGML_TYPE_BF16: {
|
|
|
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
|
|
|
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
|
+ mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
|
|
|
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
|
|
|
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
|
|
|
} break;
|
|
|
@@ -442,15 +440,15 @@ void ggml_cuda_op_mul_mat_vec(
|
|
|
GGML_UNUSED(src1_padded_row_size);
|
|
|
}
|
|
|
|
|
|
-bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
|
|
|
+bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
|
|
|
if (src0_ne[0] % 2 != 0) {
|
|
|
return false;
|
|
|
}
|
|
|
switch (type) {
|
|
|
case GGML_TYPE_F32:
|
|
|
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
|
|
- if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
|
|
- return ne11 <= 8;
|
|
|
+ if (ampere_mma_available(cc)) {
|
|
|
+ return ne11 <= 3;
|
|
|
}
|
|
|
if (cc >= GGML_CUDA_CC_TURING) {
|
|
|
return ne11 <= 4;
|
|
|
@@ -466,6 +464,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
|
|
|
case GGML_TYPE_F16:
|
|
|
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
|
|
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
|
|
|
+ if (ampere_mma_available(cc)) {
|
|
|
+ return src0_small && ne11 == 1;
|
|
|
+ }
|
|
|
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
|
|
return src0_small && ne11 <= 4;
|
|
|
}
|
|
|
@@ -486,6 +487,9 @@ bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_
|
|
|
case GGML_TYPE_BF16:
|
|
|
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
|
|
|
const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
|
|
|
+ if (ampere_mma_available(cc)) {
|
|
|
+ return src0_small && ne11 == 1;
|
|
|
+ }
|
|
|
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
|
|
|
return src0_small && ne11 <= 4;
|
|
|
}
|