@@ -428,7 +428,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     char * src0_ddc = (char *) src0->data;
     char * src1_ddc = (char *) src1->data;
 
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
+        CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -449,9 +452,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else {
-        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
     }
 }
 
@@ -461,29 +463,30 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 }
 
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-            return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
+        return nullptr;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_f32_f16<cpy_1_f32_f32>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-            return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-            return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-            return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-            return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-            return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-            return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-            return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-            return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+        return (void*) cpy_f32_f16<cpy_1_f32_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-            return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+        return (void*) cpy_f32_f16<cpy_1_f16_f32>;
     } else {
-        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
     }
 }