@@ -786,6 +786,133 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
     ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }

+void ggml_gemv_q8_0_4x4_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t acc = vdupq_n_f32(0);
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x16x2_t a = vld1q_s8_x2(a_ptr->qs);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
+            int32x4_t ret = vdupq_n_s32(0);
+
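+            // Each vdotq_laneq consumes one 4-byte group of the activation
+            // (selected by the lane index) against the matching interleaved
+            // 4-byte groups of all four columns in b_low/b_high.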
+            ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0);
+            ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1);
+            ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2);
+            ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3);
+
+            ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0);
+            ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1);
+            ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2);
+            ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemv_q8_0_4x8_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx;
+
+    for (int c = 0; c < nc; c += ncols_interleaved) {
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float32x4_t acc = vdupq_n_f32(0);
+
+        for (int b = 0; b < nb; b++) {
+            int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs);
+            int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64);
+            float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d);
+
+            int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs);
+            int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]);
+            int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]);
+            int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]);
+            int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]);
+            float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d);
+
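+            // The 8-byte activation chunks are duplicated into both halves of a
+            // 128-bit register so one vdotq covers two interleaved columns at once:
+            // ret0 collects partial sums for columns 0-1, ret1 for columns 2-3.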
+            int32x4_t ret0 = vdupq_n_s32(0);
+            int32x4_t ret1 = vdupq_n_s32(0);
+
+            // 0..7
+            ret0 = vdotq_s32(ret0, b_low.val[0], a0);
+            ret1 = vdotq_s32(ret1, b_low.val[1], a0);
+            // 8..15
+            ret0 = vdotq_s32(ret0, b_low.val[2], a1);
+            ret1 = vdotq_s32(ret1, b_low.val[3], a1);
+            // 16..23
+            ret0 = vdotq_s32(ret0, b_high.val[0], a2);
+            ret1 = vdotq_s32(ret1, b_high.val[1], a2);
+            // 24..31
+            ret0 = vdotq_s32(ret0, b_high.val[2], a3);
+            ret1 = vdotq_s32(ret1, b_high.val[3], a3);
+
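+            // The pairwise add folds the two partial lanes per column into a
+            // single lane each: ret = [col0, col1, col2, col3].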
+            int32x4_t ret = vpaddq_s32(ret0, ret1);
+
+            acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd)));
+            a_ptr++;
+            b_ptr++;
+        }
+        vst1q_f32(s, acc);
+        s += ncols_interleaved;
+    }
+    return;
+
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -2610,3 +2737,159 @@ void ggml_gemm_q4_K_8x8_q8_K(int n,
 #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
     ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
+
+
+void ggml_gemm_q8_0_4x4_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb);
+
+            float32x4_t sumf[4];
+            for (int m = 0; m < 4; m++) {
+                sumf[m] = vdupq_n_f32(0);
+            }
+
+            for (int l = 0; l < nb; l++) {
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *) a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *) b_ptr[l].d));
+
+                int32x4_t sumi_0 = vdupq_n_s32(0);
+                int32x4_t sumi_1 = vdupq_n_s32(0);
+                int32x4_t sumi_2 = vdupq_n_s32(0);
+                int32x4_t sumi_3 = vdupq_n_s32(0);
+
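+                // Both operands are 4-way interleaved with blocklen 4. Lane m of
+                // each activation vector selects row m of the tile, so sumi_m
+                // accumulates the dot products of row m against all four columns.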
+                for (int k_group = 0; k_group < 8; k_group += 4) {
+                    int8x16x4_t a = vld1q_s8_x4(a_ptr[l].qs + 16 * k_group);
+                    int8x16x4_t b = vld1q_s8_x4(b_ptr[l].qs + 16 * k_group);
+
+                    for (int k = 0; k < 4; k++) {
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b.val[k], a.val[k], 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b.val[k], a.val[k], 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b.val[k], a.val[k], 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b.val[k], a.val[k], 3);
+                    }
+                }
+
+                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+            }
+
+            for (int m = 0; m < 4; m++) {
+                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+            }
+        }
+    }
+    return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
+void ggml_gemm_q8_0_4x8_q8_0(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx;
+
+    for (int y = 0; y < nr; y += 4) {
+        const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb;
+
+        for (int x = 0; x < nc; x += ncols_interleaved) {
+            const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb;
+            const block_q8_0x4 * a_ptr = a_ptr_base;
+
+            float32x4_t acc_f32[4];
+            for (int i = 0; i < 4; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                int32x4_t acc[4];
+                for (int i = 0; i < 4; i++) {
+                    acc[i] = vdupq_n_s32(0);
+                }
+
+                // Process 4 chunks of 8 positions each
+                for (int chunk = 0; chunk < 4; chunk++) {
+                    int8x16_t a01 = vld1q_s8(a_ptr->qs + chunk * 32);
+                    int8x16_t a23 = vld1q_s8(a_ptr->qs + chunk * 32 + 16);
+                    int8x16_t b01 = vld1q_s8(b_ptr->qs + chunk * 32);
+                    int8x16_t b23 = vld1q_s8(b_ptr->qs + chunk * 32 + 16);
+
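+                    // Each vmmlaq_s32 treats its two int8x16 operands as 2x8
+                    // matrices and accumulates the 2x2 product A * B^T: two
+                    // activation rows against two interleaved weight columns.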
+                    acc[0] = vmmlaq_s32(acc[0], a01, b01);
+                    acc[1] = vmmlaq_s32(acc[1], a01, b23);
+                    acc[2] = vmmlaq_s32(acc[2], a23, b01);
+                    acc[3] = vmmlaq_s32(acc[3], a23, b23);
+                }
+
+                // Reorder outputs from 2×2 tiles to row-major
+                // acc[0] = [r0c0, r0c1, r1c0, r1c1]
+                // acc[1] = [r0c2, r0c3, r1c2, r1c3]
+                // acc[2] = [r2c0, r2c1, r3c0, r3c1]
+                // acc[3] = [r2c2, r2c3, r3c2, r3c3]
+                int32x4_t row0 = vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1]));
+                int32x4_t row1 = vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1]));
+                int32x4_t row2 = vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3]));
+                int32x4_t row3 = vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3]));
+
+                // Scales
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const __fp16 *) a_ptr->d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const __fp16 *) b_ptr->d));
+
+                acc_f32[0] = vfmaq_f32(acc_f32[0], vcvtq_f32_s32(row0), vmulq_laneq_f32(b_d, a_d, 0));
+                acc_f32[1] = vfmaq_f32(acc_f32[1], vcvtq_f32_s32(row1), vmulq_laneq_f32(b_d, a_d, 1));
+                acc_f32[2] = vfmaq_f32(acc_f32[2], vcvtq_f32_s32(row2), vmulq_laneq_f32(b_d, a_d, 2));
+                acc_f32[3] = vfmaq_f32(acc_f32[3], vcvtq_f32_s32(row3), vmulq_laneq_f32(b_d, a_d, 3));
+
+                a_ptr++;
+                b_ptr++;
+            }
+
+            for (int row = 0; row < 4; row++) {
+                vst1q_f32(s + (y + row) * bs + x, acc_f32[row]);
+            }
+        }
+    }
+    return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}