@@ -1852,6 +1852,9 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
     ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());
 
+    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
+    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);
+
     // Handle src0
     src0_ptr = (const cuda_t *) src0->data;
 
@@ -1870,6 +1873,8 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         s11 = ne10;
         s12 = ne11*s11;
         s13 = ne12*s12;
+
+        is_src1_cont_2 = true;
     }
 
     // Setup destination buffer
@@ -1918,15 +1923,19 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
+        // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
+        const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
+        const int64_t smb = ne12 == 1 ? s13 : s12;
+
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         // use cublasGemmStridedBatchedEx
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, nb02/nb00, // strideA
-                       src1_ptr, cu_data_type_b, s11,       s12,       // strideB
-                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0,   // strideC
+                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA
+                       src1_ptr, cu_data_type_b, s11,       smb,     // strideB
+                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0, // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
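
For context (not part of the diff above): the stride fix matters because, after a [0, 2, 1, 3] permutation that leaves ne02 == 1, consecutive matrices of the batch are laid out along dim 3, so nb02 no longer measures the distance between them. The standalone C++ sketch below works through one hypothetical f32 view; the shapes, strides, and the helper name are illustrative only, not taken from ggml.

#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the selection the diff introduces:
//   sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
// ne[] are element counts per dim, nb[] are byte strides per dim (ggml-style).
static int64_t batch_stride_elems(const int64_t ne[4], const size_t nb[4]) {
    return ne[2] == 1 ? (int64_t)(nb[3]/nb[0]) : (int64_t)(nb[2]/nb[0]);
}

int main() {
    // Assumed example: an f32 tensor of shape [8, 1, 4, 2], contiguous, viewed
    // through a [0, 2, 1, 3] permutation (dims 1 and 2 swapped). The view has
    //   ne = {8, 4, 1, 2}   and   nb = {4, 32, 32, 128} bytes.
    const int64_t ne[4] = {8, 4, 1, 2};
    const size_t  nb[4] = {4, 32, 32, 128};

    // Each matrix is ne0*ne1 = 32 elements. nb[2]/nb[0] = 8 would make the two
    // batch matrices overlap; nb[3]/nb[0] = 32 is their true separation.
    printf("batch stride: %lld elements (nb02/nb00 would give %lld)\n",
           (long long) batch_stride_elems(ne, nb),
           (long long) (nb[2]/nb[0]));
    return 0;
}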