@@ -10,6 +10,13 @@ static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    nv_bfloat16 * dsti = (nv_bfloat16 *) cdsti;
+
+    *dsti = *xi;
+}
+
 static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     const float * xi = (const float *) cxi;
     half * dsti = (half *) cdsti;
@@ -386,6 +393,16 @@ static void ggml_cpy_f32_f32_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+static void ggml_cpy_f32_bf16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f32_bf16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+}
+
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -581,6 +598,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
@@ -634,6 +653,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return nullptr;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
+        return (void*) cpy_f32_f16<cpy_1_f32_bf16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {