|
@@ -7330,6 +7330,7 @@ static void ggml_compute_forward_group_norm(
|
|
|
static void ggml_compute_forward_mul_mat_one_chunk(
|
|
static void ggml_compute_forward_mul_mat_one_chunk(
|
|
|
const struct ggml_compute_params * params,
|
|
const struct ggml_compute_params * params,
|
|
|
struct ggml_tensor * dst,
|
|
struct ggml_tensor * dst,
|
|
|
|
|
+ const enum ggml_type type,
|
|
|
const int64_t num_rows_per_vec_dot,
|
|
const int64_t num_rows_per_vec_dot,
|
|
|
const int64_t ir0_start,
|
|
const int64_t ir0_start,
|
|
|
const int64_t ir0_end,
|
|
const int64_t ir0_end,
|
|
@@ -7341,8 +7342,6 @@ static void ggml_compute_forward_mul_mat_one_chunk(
|
|
|
|
|
|
|
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
GGML_TENSOR_BINARY_OP_LOCALS
|
|
|
|
|
|
|
|
- const enum ggml_type type = src0->type;
|
|
|
|
|
-
|
|
|
|
|
const bool src1_cont = ggml_is_contiguous(src1);
|
|
const bool src1_cont = ggml_is_contiguous(src1);
|
|
|
|
|
|
|
|
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
|
ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot;
|
|
@@ -7430,7 +7429,11 @@ static void ggml_compute_forward_mul_mat(
|
|
|
const int ith = params->ith;
|
|
const int ith = params->ith;
|
|
|
const int nth = params->nth;
|
|
const int nth = params->nth;
|
|
|
|
|
|
|
|
- const enum ggml_type type = src0->type;
|
|
|
|
|
|
|
+ enum ggml_type type = src0->type;
|
|
|
|
|
+
|
|
|
|
|
+ if (src0->buffer && ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
|
|
|
|
|
+ type = (enum ggml_type)(intptr_t)src0->extra;
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
|
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
|
|
|
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
|
|
ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float;
|
|
@@ -7469,15 +7472,15 @@ static void ggml_compute_forward_mul_mat(
|
|
|
if (src1_cont) {
|
|
if (src1_cont) {
|
|
|
for (int64_t i13 = 0; i13 < ne13; i13++)
|
|
for (int64_t i13 = 0; i13 < ne13; i13++)
|
|
|
for (int64_t i12 = 0; i12 < ne12; i12++)
|
|
for (int64_t i12 = 0; i12 < ne12; i12++)
|
|
|
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
|
|
|
|
|
|
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
|
|
|
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
|
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
|
|
- nb01/ggml_type_size(src0->type),
|
|
|
|
|
|
|
+ nb01/ggml_type_size(type),
|
|
|
(const char *)src1->data + i12*nb12 + i13*nb13,
|
|
(const char *)src1->data + i12*nb12 + i13*nb13,
|
|
|
nb11/ggml_type_size(src1->type),
|
|
nb11/ggml_type_size(src1->type),
|
|
|
(char *)dst->data + i12*nb2 + i13*nb3,
|
|
(char *)dst->data + i12*nb2 + i13*nb3,
|
|
|
nb1/ggml_type_size(dst->type),
|
|
nb1/ggml_type_size(dst->type),
|
|
|
ith, nth,
|
|
ith, nth,
|
|
|
- src0->type,
|
|
|
|
|
|
|
+ type,
|
|
|
src1->type,
|
|
src1->type,
|
|
|
dst->type))
|
|
dst->type))
|
|
|
goto UseGgmlGemm1;
|
|
goto UseGgmlGemm1;
|
|
@@ -7530,15 +7533,15 @@ UseGgmlGemm1:;
|
|
|
|
|
|
|
|
for (int64_t i13 = 0; i13 < ne13; i13++)
|
|
for (int64_t i13 = 0; i13 < ne13; i13++)
|
|
|
for (int64_t i12 = 0; i12 < ne12; i12++)
|
|
for (int64_t i12 = 0; i12 < ne12; i12++)
|
|
|
- if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
|
|
|
|
|
|
+ if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(type),
|
|
|
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
|
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
|
|
|
- nb01/ggml_type_size(src0->type),
|
|
|
|
|
|
|
+ nb01/ggml_type_size(type),
|
|
|
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
|
|
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
|
|
|
row_size/ggml_type_size(vec_dot_type),
|
|
row_size/ggml_type_size(vec_dot_type),
|
|
|
(char *)dst->data + i12*nb2 + i13*nb3,
|
|
(char *)dst->data + i12*nb2 + i13*nb3,
|
|
|
nb1/ggml_type_size(dst->type),
|
|
nb1/ggml_type_size(dst->type),
|
|
|
ith, nth,
|
|
ith, nth,
|
|
|
- src0->type,
|
|
|
|
|
|
|
+ type,
|
|
|
vec_dot_type,
|
|
vec_dot_type,
|
|
|
dst->type))
|
|
dst->type))
|
|
|
goto UseGgmlGemm2;
|
|
goto UseGgmlGemm2;
|
|
@@ -7623,7 +7626,7 @@ UseGgmlGemm2:;
|
|
|
const int64_t ir1_start = dr1 * ith1;
|
|
const int64_t ir1_start = dr1 * ith1;
|
|
|
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
|
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
|
|
|
|
|
|
|
- ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
|
|
|
|
|
|
+ ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
|
|
|
|
|
|
|
if (nth >= nchunk0 * nchunk1) {
|
|
if (nth >= nchunk0 * nchunk1) {
|
|
|
break;
|
|
break;
|