|
@@ -1114,8 +1114,8 @@ static void ggml_cuda_op_mul_mat_cublas(
|
|
|
CUBLAS_CHECK(
|
|
CUBLAS_CHECK(
|
|
|
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
|
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
|
|
row_diff, src1_ncols, ne10,
|
|
row_diff, src1_ncols, ne10,
|
|
|
- &alpha, src0_ptr, CUDA_R_16F, ne00,
|
|
|
|
|
- src1_ptr, CUDA_R_16F, ne10,
|
|
|
|
|
|
|
+ &alpha, src0_ptr, CUDA_R_16F, ne00,
|
|
|
|
|
+ src1_ptr, CUDA_R_16F, ne10,
|
|
|
&beta, dst_dd_i, CUDA_R_32F, ldc,
|
|
&beta, dst_dd_i, CUDA_R_32F, ldc,
|
|
|
CUBLAS_COMPUTE_32F,
|
|
CUBLAS_COMPUTE_32F,
|
|
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
|
@@ -1128,9 +1128,9 @@ static void ggml_cuda_op_mul_mat_cublas(
|
|
|
CUBLAS_CHECK(
|
|
CUBLAS_CHECK(
|
|
|
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
|
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
|
|
|
row_diff, src1_ncols, ne10,
|
|
row_diff, src1_ncols, ne10,
|
|
|
- &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
|
|
|
|
- src1_ptr, CUDA_R_16F, ne10,
|
|
|
|
|
- &beta_f16, dst_dd_i, CUDA_R_16F, ldc,
|
|
|
|
|
|
|
+ &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
|
|
|
|
|
+ src1_ptr, CUDA_R_16F, ne10,
|
|
|
|
|
+ &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
|
|
|
CUBLAS_COMPUTE_16F,
|
|
CUBLAS_COMPUTE_16F,
|
|
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
|
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
|
|
|
|
|