|
|
@@ -278,6 +278,72 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
|
|
|
#endif
|
|
|
}
|
|
|
|
|
|
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
|
|
+ assert(nrc == 1);
|
|
|
+ UNUSED(nrc);
|
|
|
+ UNUSED(bx);
|
|
|
+ UNUSED(by);
|
|
|
+ UNUSED(bs);
|
|
|
+ assert(n % QK_MXFP4 == 0);
|
|
|
+ static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
|
|
|
+
|
|
|
+ const block_mxfp4 * GGML_RESTRICT x = vx;
|
|
|
+ const block_q8_0 * GGML_RESTRICT y = vy;
|
|
|
+
|
|
|
+ const int nb = n / QK_MXFP4;
|
|
|
+
|
|
|
+ int ib = 0;
|
|
|
+ float sumf = 0;
|
|
|
+
|
|
|
+#if defined(__POWER9_VECTOR__)
|
|
|
+ const vector signed char lowMask = vec_splats((signed char)0xF);
|
|
|
+ const vector unsigned char vshift4 = vec_splats((unsigned char)4);
|
|
|
+ vector float vsumf0 = vec_splats(0.0f);
|
|
|
+
|
|
|
+ vector signed char kv = vec_xl(0, (const signed char *)kvalues_mxfp4);
|
|
|
+
|
|
|
+#pragma GCC unroll 8
|
|
|
+ for (; ib < nb; ++ib) {
|
|
|
+ __builtin_prefetch(x[ib].qs, 0, 1);
|
|
|
+ __builtin_prefetch(y[ib].qs, 0, 1);
|
|
|
+
|
|
|
+ vector float vyd = vec_splats(GGML_CPU_FP16_TO_FP32(y[ib].d) *
|
|
|
+ GGML_E8M0_TO_FP32_HALF(x[ib].e));
|
|
|
+
|
|
|
+ vector signed char q8y0 = vec_xl( 0, y[ib].qs);
|
|
|
+ vector signed char q8y1 = vec_xl(16, y[ib].qs);
|
|
|
+
|
|
|
+ vector signed char qxs = (vector signed char)vec_xl(0, x[ib].qs);
|
|
|
+
|
|
|
+ vector unsigned char lo_nibbles = (vector unsigned char)vec_and(qxs, lowMask);
|
|
|
+ vector unsigned char hi_nibbles = (vector unsigned char)vec_sr(qxs, vshift4);
|
|
|
+
|
|
|
+ vector signed char q4x0 = vec_perm(kv, kv, lo_nibbles);
|
|
|
+ vector signed char q4x1 = vec_perm(kv, kv, hi_nibbles);
|
|
|
+
|
|
|
+ vector signed short qv0 = vec_add(vec_mule(q4x0, q8y0), vec_mulo(q4x0, q8y0));
|
|
|
+ vector signed short qv1 = vec_add(vec_mule(q4x1, q8y1), vec_mulo(q4x1, q8y1));
|
|
|
+
|
|
|
+ vector signed int vsumi0 = vec_splats((int32_t)0);
|
|
|
+ vsumi0 = vec_sum4s(qv0, vsumi0);
|
|
|
+ vsumi0 = vec_sum4s(qv1, vsumi0);
|
|
|
+
|
|
|
+ vsumf0 = vec_madd(vec_ctf(vsumi0, 0), vyd, vsumf0);
|
|
|
+ }
|
|
|
+
|
|
|
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 4));
|
|
|
+ vsumf0 = vec_add(vsumf0, vec_sld(vsumf0, vsumf0, 8));
|
|
|
+ sumf = vec_extract(vsumf0, 0);
|
|
|
+ *s = sumf;
|
|
|
+#else
|
|
|
+ UNUSED(x);
|
|
|
+ UNUSED(y);
|
|
|
+ UNUSED(ib);
|
|
|
+ UNUSED(sumf);
|
|
|
+ ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc);
|
|
|
+#endif
|
|
|
+}
|
|
|
+
|
|
|
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
|
|
|
const int qk = QK8_0;
|
|
|
const int nb = n / qk;
|