@@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
 const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
 
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
-                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                               const int nb12, const int nb13) {
+static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
+                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                  const int nb12, const int nb13) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
@@ -40,7 +40,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
 }
 
 template <typename T>
-static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                const int nb12, const int nb13) {
@@ -166,7 +166,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
 }
 
 template<typename src_t, typename dst_t>
-static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
+static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
@@ -180,17 +180,17 @@ static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const in
 }
 
 template<typename src_t, typename dst_t>
-static void ggml_cpy_flt_contiguous_cuda(
+static void ggml_cpy_scalar_contiguous_cuda(
     const char * cx, char * cdst, const int64_t ne,
     cudaStream_t stream) {
 
     const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne);
 }
 
 template<typename src_t, typename dst_t, bool transposed = false>
-static void ggml_cpy_flt_cuda(
+static void ggml_cpy_scalar_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
@@ -212,11 +212,11 @@ static void ggml_cpy_flt_cuda(
             (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
             (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
         dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
-        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
+        cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
             (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     } else {
         const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
             (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     }
 }
@@ -399,94 +399,132 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         if (can_be_transposed) {
-            ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, float, true>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
-            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, float>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<float, half> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<float, half>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<float, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, half>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q8_0_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q8_0_f32_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_0_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-                               nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q4_0_f32_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_1_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-                               nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q4_1_f32_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_0_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-                               nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q5_0_f32_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_iq4_nl_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_1_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q5_1_f32_cuda
+            (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         if (can_be_transposed) {
-            ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<half, half, true>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
-            ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<half, half>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<half, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<half, float> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<half, float>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<half, float>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
         if (can_be_transposed) {
-            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
-            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<nv_bfloat16, half>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
+                (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_scalar_cuda<nv_bfloat16, float>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+        if (can_be_transposed) {
+            ggml_cpy_scalar_cuda<int32_t, int32_t, true>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
-            ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<int32_t, int32_t>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<float, int32_t>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, int32_t>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<int32_t, float>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<int32_t, float>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,