|
|
@@ -1540,7 +1540,7 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
|
|
*s = sumf;
|
|
|
}
|
|
|
|
|
|
-inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
|
|
+static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
|
|
const int nb = n / QK;
|
|
|
|
|
|
assert(n % QK == 0);
|
|
|
@@ -1824,7 +1824,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|
|
*s = sumf;
|
|
|
}
|
|
|
|
|
|
-inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
|
|
+static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
|
|
const int nb = n / QK;
|
|
|
|
|
|
const block_q4_1 * restrict x = vx;
|
|
|
@@ -6106,188 +6106,30 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|
|
//}
|
|
|
}
|
|
|
|
|
|
-static void ggml_compute_forward_mul_mat_q4_0_f32(
|
|
|
- const struct ggml_compute_params * params,
|
|
|
- const struct ggml_tensor * src0,
|
|
|
- const struct ggml_tensor * src1,
|
|
|
- struct ggml_tensor * dst) {
|
|
|
- int64_t t0 = ggml_perf_time_us();
|
|
|
- UNUSED(t0);
|
|
|
-
|
|
|
- const int ne00 = src0->ne[0];
|
|
|
- const int ne01 = src0->ne[1];
|
|
|
- const int ne02 = src0->ne[2];
|
|
|
- const int ne03 = src0->ne[3];
|
|
|
-
|
|
|
- const int ne10 = src1->ne[0];
|
|
|
- const int ne11 = src1->ne[1];
|
|
|
- const int ne12 = src1->ne[2];
|
|
|
- const int ne13 = src1->ne[3];
|
|
|
-
|
|
|
- const int ne0 = dst->ne[0];
|
|
|
- const int ne1 = dst->ne[1];
|
|
|
- const int ne2 = dst->ne[2];
|
|
|
- const int ne3 = dst->ne[3];
|
|
|
-
|
|
|
- const int nb00 = src0->nb[0];
|
|
|
- const int nb01 = src0->nb[1];
|
|
|
- const int nb02 = src0->nb[2];
|
|
|
- const int nb03 = src0->nb[3];
|
|
|
-
|
|
|
- const int nb10 = src1->nb[0];
|
|
|
- const int nb11 = src1->nb[1];
|
|
|
- const int nb12 = src1->nb[2];
|
|
|
- const int nb13 = src1->nb[3];
|
|
|
-
|
|
|
- const int nb0 = dst->nb[0];
|
|
|
- const int nb1 = dst->nb[1];
|
|
|
- const int nb2 = dst->nb[2];
|
|
|
- const int nb3 = dst->nb[3];
|
|
|
+typedef void (*dequantize_row_q_t)(const void * restrict x, float * restrict y, int k);
|
|
|
+typedef void (*quantize_row_q_t)(const float * restrict x, void * restrict y, int k);
|
|
|
+typedef void (*vec_dot_q_t)(const int n, float * restrict s, const void * restrict x, const void * restrict y);
|
|
|
|
|
|
- const int ith = params->ith;
|
|
|
- const int nth = params->nth;
|
|
|
-
|
|
|
- GGML_ASSERT(ne02 == ne12);
|
|
|
- GGML_ASSERT(ne03 == ne13);
|
|
|
- GGML_ASSERT(ne2 == ne12);
|
|
|
- GGML_ASSERT(ne3 == ne13);
|
|
|
-
|
|
|
- // we don't support permuted src0 or src1
|
|
|
- GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
|
|
|
- GGML_ASSERT(nb10 == sizeof(float));
|
|
|
-
|
|
|
- // dst cannot be transposed or permuted
|
|
|
- GGML_ASSERT(nb0 == sizeof(float));
|
|
|
- GGML_ASSERT(nb0 <= nb1);
|
|
|
- GGML_ASSERT(nb1 <= nb2);
|
|
|
- GGML_ASSERT(nb2 <= nb3);
|
|
|
-
|
|
|
- GGML_ASSERT(ne0 == ne01);
|
|
|
- GGML_ASSERT(ne1 == ne11);
|
|
|
- GGML_ASSERT(ne2 == ne02);
|
|
|
- GGML_ASSERT(ne3 == ne03);
|
|
|
-
|
|
|
- // nb01 >= nb00 - src0 is not transposed
|
|
|
- // compute by src0 rows
|
|
|
-
|
|
|
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
|
|
- if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
|
|
- if (params->ith != 0) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- if (params->type == GGML_TASK_INIT) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- if (params->type == GGML_TASK_FINALIZE) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- float * const wdata = params->wdata;
|
|
|
-
|
|
|
- for (int i03 = 0; i03 < ne03; i03++) {
|
|
|
- for (int i02 = 0; i02 < ne02; i02++) {
|
|
|
- {
|
|
|
- size_t id = 0;
|
|
|
- for (int i01 = 0; i01 < ne01; ++i01) {
|
|
|
- dequantize_row_q4_0((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
|
|
|
- id += ne00;
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- const float * x = wdata;
|
|
|
- const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
|
|
-
|
|
|
- float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
|
|
-
|
|
|
- // zT = y * xT
|
|
|
- cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
|
|
- ne11, ne01, ne10,
|
|
|
- 1.0f, y, ne10,
|
|
|
- x, ne10,
|
|
|
- 0.0f, d, ne01);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- /*printf("CBLAS Q4_0 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
|
|
|
-
|
|
|
- return;
|
|
|
- }
|
|
|
-#endif
|
|
|
-
|
|
|
- if (params->type == GGML_TASK_INIT) {
|
|
|
- char * wdata = params->wdata;
|
|
|
-
|
|
|
- for (int i13 = 0; i13 < ne13; ++i13) {
|
|
|
- for (int i12 = 0; i12 < ne12; ++i12) {
|
|
|
- for (int i11 = 0; i11 < ne11; ++i11) {
|
|
|
- quantize_row_q4_0((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
|
|
- wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- if (params->type == GGML_TASK_FINALIZE) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- // parallelize by src0 rows using ggml_vec_dot_q4_0
|
|
|
-
|
|
|
- // total rows in src0
|
|
|
- const int nr = ne01*ne02*ne03;
|
|
|
-
|
|
|
- // rows per thread
|
|
|
- const int dr = (nr + nth - 1)/nth;
|
|
|
-
|
|
|
- // row range for this thread
|
|
|
- const int ir0 = dr*ith;
|
|
|
- const int ir1 = MIN(ir0 + dr, nr);
|
|
|
-
|
|
|
- void * wdata = params->wdata;
|
|
|
-
|
|
|
- for (int ir = ir0; ir < ir1; ++ir) {
|
|
|
- // src0 indices
|
|
|
- const int i03 = ir/(ne02*ne01);
|
|
|
- const int i02 = (ir - i03*ne02*ne01)/ne01;
|
|
|
- const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
|
-
|
|
|
- const int i13 = i03;
|
|
|
- const int i12 = i02;
|
|
|
-
|
|
|
- const int i0 = i01;
|
|
|
- const int i2 = i02;
|
|
|
- const int i3 = i03;
|
|
|
-
|
|
|
- void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
|
|
- char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0]);
|
|
|
-
|
|
|
- float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
|
|
-
|
|
|
- assert(ne00 % 32 == 0);
|
|
|
-
|
|
|
- for (int ic = 0; ic < ne11; ++ic) {
|
|
|
- ggml_vec_dot_q4_0(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_0])/GGML_BLCK_SIZE[GGML_TYPE_Q4_0])));
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- //int64_t t1 = ggml_time_us();
|
|
|
- //static int64_t acc = 0;
|
|
|
- //acc += t1 - t0;
|
|
|
- //if (t1 - t0 > 10) {
|
|
|
- // printf("\n");
|
|
|
- // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
|
|
|
- // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
|
|
|
- // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
|
|
|
-
|
|
|
- // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
|
|
|
- //}
|
|
|
-}
|
|
|
+typedef struct {
|
|
|
+ dequantize_row_q_t dequantize_row_q;
|
|
|
+ quantize_row_q_t quantize_row_q;
|
|
|
+ vec_dot_q_t vec_dot_q;
|
|
|
+} quantize_fns_t;
|
|
|
+
|
|
|
+static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
|
|
+ [GGML_TYPE_Q4_0] = {
|
|
|
+ .dequantize_row_q = dequantize_row_q4_0,
|
|
|
+ .quantize_row_q = quantize_row_q4_0,
|
|
|
+ .vec_dot_q = ggml_vec_dot_q4_0,
|
|
|
+ },
|
|
|
+ [GGML_TYPE_Q4_1] = {
|
|
|
+ .dequantize_row_q = dequantize_row_q4_1,
|
|
|
+ .quantize_row_q = quantize_row_q4_1,
|
|
|
+ .vec_dot_q = ggml_vec_dot_q4_1,
|
|
|
+ },
|
|
|
+};
|
|
|
|
|
|
-static void ggml_compute_forward_mul_mat_q4_1_f32(
|
|
|
+static void ggml_compute_forward_mul_mat_q_f32(
|
|
|
const struct ggml_compute_params * params,
|
|
|
const struct ggml_tensor * src0,
|
|
|
const struct ggml_tensor * src1,
|
|
|
@@ -6333,8 +6175,12 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
|
|
GGML_ASSERT(ne2 == ne12);
|
|
|
GGML_ASSERT(ne3 == ne13);
|
|
|
|
|
|
+ const enum ggml_type type = src0->type;
|
|
|
+ quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
|
|
|
+ vec_dot_q_t const vec_dot_q = quantize_fns[type].vec_dot_q;
|
|
|
+
|
|
|
// we don't support permuted src0 or src1
|
|
|
- GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
|
|
|
+ GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
|
|
|
GGML_ASSERT(nb10 == sizeof(float));
|
|
|
|
|
|
// dst cannot be transposed or permuted
|
|
|
@@ -6366,13 +6212,14 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
|
|
}
|
|
|
|
|
|
float * const wdata = params->wdata;
|
|
|
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
|
|
|
|
|
for (int i03 = 0; i03 < ne03; i03++) {
|
|
|
for (int i02 = 0; i02 < ne02; i02++) {
|
|
|
{
|
|
|
size_t id = 0;
|
|
|
for (int i01 = 0; i01 < ne01; ++i01) {
|
|
|
- dequantize_row_q4_1((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
|
|
|
+ dequantize_row_q((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01, wdata + id, ne00);
|
|
|
id += ne00;
|
|
|
}
|
|
|
}
|
|
|
@@ -6399,15 +6246,13 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
|
|
|
|
|
if (params->type == GGML_TASK_INIT) {
|
|
|
char * wdata = params->wdata;
|
|
|
+ const size_t row_size = ne10*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
|
|
|
|
|
|
for (int i13 = 0; i13 < ne13; ++i13) {
|
|
|
for (int i12 = 0; i12 < ne12; ++i12) {
|
|
|
for (int i11 = 0; i11 < ne11; ++i11) {
|
|
|
- //for (int i10 = 0; i10 < ne10; ++i10) {
|
|
|
- // wdata[id++] = GGML_FP32_TO_FP16(*(float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10));
|
|
|
- //}
|
|
|
- quantize_row_q4_1((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
|
|
- wdata += (ne10*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
|
|
|
+ quantize_row_q((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
|
|
|
+ wdata += row_size;
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -6419,7 +6264,7 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- // parallelize by src0 rows using ggml_vec_dot_q4_1
|
|
|
+ // parallelize by src0 rows using ggml_vec_dot_q
|
|
|
|
|
|
// total rows in src0
|
|
|
const int nr = ne01*ne02*ne03;
|
|
|
@@ -6432,6 +6277,7 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
|
|
const int ir1 = MIN(ir0 + dr, nr);
|
|
|
|
|
|
void * wdata = params->wdata;
|
|
|
+ const size_t row_size = ne00*GGML_TYPE_SIZE[type]/GGML_BLCK_SIZE[type];
|
|
|
|
|
|
for (int ir = ir0; ir < ir1; ++ir) {
|
|
|
// src0 indices
|
|
|
@@ -6447,14 +6293,14 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
|
|
|
const int i3 = i03;
|
|
|
|
|
|
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
|
|
- char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1]);
|
|
|
+ char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
|
|
|
|
|
|
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
|
|
|
|
|
assert(ne00 % 32 == 0);
|
|
|
|
|
|
for (int ic = 0; ic < ne11; ++ic) {
|
|
|
- ggml_vec_dot_q4_1(ne00, &dst_col[ic*ne0], src0_row, ((void *) (src1_col + (ic*ne00*GGML_TYPE_SIZE[GGML_TYPE_Q4_1])/GGML_BLCK_SIZE[GGML_TYPE_Q4_1])));
|
|
|
+ vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -6478,12 +6324,9 @@ static void ggml_compute_forward_mul_mat(
|
|
|
struct ggml_tensor * dst) {
|
|
|
switch (src0->type) {
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
- {
|
|
|
- ggml_compute_forward_mul_mat_q4_0_f32(params, src0, src1, dst);
|
|
|
- } break;
|
|
|
case GGML_TYPE_Q4_1:
|
|
|
{
|
|
|
- ggml_compute_forward_mul_mat_q4_1_f32(params, src0, src1, dst);
|
|
|
+ ggml_compute_forward_mul_mat_q_f32(params, src0, src1, dst);
|
|
|
} break;
|
|
|
case GGML_TYPE_F16:
|
|
|
{
|
|
|
@@ -6644,7 +6487,7 @@ static void ggml_compute_forward_transpose(
|
|
|
|
|
|
// ggml_compute_forward_get_rows
|
|
|
|
|
|
-static void ggml_compute_forward_get_rows_q4_0(
|
|
|
+static void ggml_compute_forward_get_rows_q(
|
|
|
const struct ggml_compute_params * params,
|
|
|
const struct ggml_tensor * src0,
|
|
|
const struct ggml_tensor * src1,
|
|
|
@@ -6657,42 +6500,17 @@ static void ggml_compute_forward_get_rows_q4_0(
|
|
|
|
|
|
const int nc = src0->ne[0];
|
|
|
const int nr = ggml_nelements(src1);
|
|
|
+ const enum ggml_type type = src0->type;
|
|
|
+ dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
|
|
|
|
|
assert( dst->ne[0] == nc);
|
|
|
assert( dst->ne[1] == nr);
|
|
|
- assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_0]);
|
|
|
+ assert(src0->nb[0] == GGML_TYPE_SIZE[type]);
|
|
|
|
|
|
for (int i = 0; i < nr; ++i) {
|
|
|
const int r = ((int32_t *) src1->data)[i];
|
|
|
|
|
|
- dequantize_row_q4_0(
|
|
|
- (const void *) ((char *) src0->data + r*src0->nb[1]),
|
|
|
- (float *) ((char *) dst->data + i*dst->nb[1]), nc);
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-static void ggml_compute_forward_get_rows_q4_1(
|
|
|
- const struct ggml_compute_params * params,
|
|
|
- const struct ggml_tensor * src0,
|
|
|
- const struct ggml_tensor * src1,
|
|
|
- struct ggml_tensor * dst) {
|
|
|
- assert(params->ith == 0);
|
|
|
-
|
|
|
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- const int nc = src0->ne[0];
|
|
|
- const int nr = ggml_nelements(src1);
|
|
|
-
|
|
|
- assert( dst->ne[0] == nc);
|
|
|
- assert( dst->ne[1] == nr);
|
|
|
- assert(src0->nb[0] == GGML_TYPE_SIZE[GGML_TYPE_Q4_1]);
|
|
|
-
|
|
|
- for (int i = 0; i < nr; ++i) {
|
|
|
- const int r = ((int32_t *) src1->data)[i];
|
|
|
-
|
|
|
- dequantize_row_q4_1(
|
|
|
+ dequantize_row_q(
|
|
|
(const void *) ((char *) src0->data + r*src0->nb[1]),
|
|
|
(float *) ((char *) dst->data + i*dst->nb[1]), nc);
|
|
|
}
|
|
|
@@ -6760,12 +6578,9 @@ static void ggml_compute_forward_get_rows(
|
|
|
struct ggml_tensor * dst) {
|
|
|
switch (src0->type) {
|
|
|
case GGML_TYPE_Q4_0:
|
|
|
- {
|
|
|
- ggml_compute_forward_get_rows_q4_0(params, src0, src1, dst);
|
|
|
- } break;
|
|
|
case GGML_TYPE_Q4_1:
|
|
|
{
|
|
|
- ggml_compute_forward_get_rows_q4_1(params, src0, src1, dst);
|
|
|
+ ggml_compute_forward_get_rows_q(params, src0, src1, dst);
|
|
|
} break;
|
|
|
case GGML_TYPE_F16:
|
|
|
{
|
|
|
@@ -9098,8 +8913,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
|
|
|
|
size_t cur = 0;
|
|
|
|
|
|
- if (node->src0->type == GGML_TYPE_F16 &&
|
|
|
- node->src1->type == GGML_TYPE_F32) {
|
|
|
+ if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
|
|
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
|
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
|
|
node->n_tasks = 1; // TODO: this actually is doing nothing
|
|
|
@@ -9114,33 +8928,18 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|
|
#else
|
|
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F16]*ggml_nelements(node->src1);
|
|
|
#endif
|
|
|
- } else if (node->src0->type == GGML_TYPE_F32 &&
|
|
|
- node->src1->type == GGML_TYPE_F32) {
|
|
|
+ } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
|
|
|
cur = 0;
|
|
|
- } else if (node->src0->type == GGML_TYPE_Q4_0 &&
|
|
|
- node->src1->type == GGML_TYPE_F32) {
|
|
|
+ } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
|
|
|
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
|
|
if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
|
|
node->n_tasks = 1;
|
|
|
cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
|
|
- } else {
|
|
|
- cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
|
|
|
- }
|
|
|
-#else
|
|
|
- cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_0]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_0];
|
|
|
+ } else
|
|
|
#endif
|
|
|
- } else if (node->src0->type == GGML_TYPE_Q4_1 &&
|
|
|
- node->src1->type == GGML_TYPE_F32) {
|
|
|
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
|
|
- if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
|
|
|
- node->n_tasks = 1;
|
|
|
- cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
|
|
|
- } else {
|
|
|
- cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
|
|
|
+ {
|
|
|
+ cur = GGML_TYPE_SIZE[node->src0->type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[node->src0->type];
|
|
|
}
|
|
|
-#else
|
|
|
- cur = (GGML_TYPE_SIZE[GGML_TYPE_Q4_1]*ggml_nelements(node->src1))/GGML_BLCK_SIZE[GGML_TYPE_Q4_1];
|
|
|
-#endif
|
|
|
} else {
|
|
|
GGML_ASSERT(false);
|
|
|
}
|