|
@@ -8322,8 +8322,7 @@ static void ggml_compute_forward_dup_same_cont(
|
|
|
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
|
|
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
- const size_t nb00 = src0->nb[0];
|
|
|
|
|
- const size_t nb0 = dst->nb[0];
|
|
|
|
|
|
|
+ const size_t nb0 = ggml_type_size(src0->type);
|
|
|
|
|
|
|
|
const int ith = params->ith; // thread index
|
|
const int ith = params->ith; // thread index
|
|
|
const int nth = params->nth; // number of threads
|
|
const int nth = params->nth; // number of threads
|
|
@@ -8337,8 +8336,8 @@ static void ggml_compute_forward_dup_same_cont(
|
|
|
if (ie0 < ie1) {
|
|
if (ie0 < ie1) {
|
|
|
memcpy(
|
|
memcpy(
|
|
|
((char *) dst->data + ie0*nb0),
|
|
((char *) dst->data + ie0*nb0),
|
|
|
- ((char *) src0->data + ie0*nb00),
|
|
|
|
|
- (ie1 - ie0) * ggml_type_size(src0->type));
|
|
|
|
|
|
|
+ ((char *) src0->data + ie0*nb0),
|
|
|
|
|
+ (ie1 - ie0) * nb0);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -8355,11 +8354,6 @@ static void ggml_compute_forward_dup_f16(
|
|
|
const int ith = params->ith; // thread index
|
|
const int ith = params->ith; // thread index
|
|
|
const int nth = params->nth; // number of threads
|
|
const int nth = params->nth; // number of threads
|
|
|
|
|
|
|
|
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
|
|
|
|
- ggml_compute_forward_dup_same_cont(params, dst);
|
|
|
|
|
- return;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
// parallelize by rows
|
|
// parallelize by rows
|
|
|
const int nr = ne01;
|
|
const int nr = ne01;
|
|
|
// number of rows per thread
|
|
// number of rows per thread
|
|
@@ -8624,11 +8618,6 @@ static void ggml_compute_forward_dup_bf16(
|
|
|
const int ith = params->ith; // thread index
|
|
const int ith = params->ith; // thread index
|
|
|
const int nth = params->nth; // number of threads
|
|
const int nth = params->nth; // number of threads
|
|
|
|
|
|
|
|
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
|
|
|
|
- ggml_compute_forward_dup_same_cont(params, dst);
|
|
|
|
|
- return;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
// parallelize by rows
|
|
// parallelize by rows
|
|
|
const int nr = ne01;
|
|
const int nr = ne01;
|
|
|
// number of rows per thread
|
|
// number of rows per thread
|
|
@@ -8980,11 +8969,6 @@ static void ggml_compute_forward_dup_f32(
|
|
|
const int ith = params->ith; // thread index
|
|
const int ith = params->ith; // thread index
|
|
|
const int nth = params->nth; // number of threads
|
|
const int nth = params->nth; // number of threads
|
|
|
|
|
|
|
|
- if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst) && src0->type == dst->type) {
|
|
|
|
|
- ggml_compute_forward_dup_same_cont(params, dst);
|
|
|
|
|
- return;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
// parallelize by rows
|
|
// parallelize by rows
|
|
|
const int nr = ne01;
|
|
const int nr = ne01;
|
|
|
// number of rows per thread
|
|
// number of rows per thread
|
|
@@ -9294,13 +9278,13 @@ static void ggml_compute_forward_dup_bytes(
|
|
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
|
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
GGML_ASSERT(src0->type == dst->type);
|
|
|
|
|
|
|
|
|
|
+ GGML_TENSOR_UNARY_OP_LOCALS;
|
|
|
|
|
+
|
|
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
|
|
if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
|
|
|
ggml_compute_forward_dup_same_cont(params, dst);
|
|
ggml_compute_forward_dup_same_cont(params, dst);
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- GGML_TENSOR_UNARY_OP_LOCALS;
|
|
|
|
|
-
|
|
|
|
|
const size_t type_size = ggml_type_size(src0->type);
|
|
const size_t type_size = ggml_type_size(src0->type);
|
|
|
const int ith = params->ith; // thread index
|
|
const int ith = params->ith; // thread index
|
|
|
const int nth = params->nth; // number of threads
|
|
const int nth = params->nth; // number of threads
|