@@ -497,6 +497,140 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
 
+void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
+
+ constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 4;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+ // 1x8 tile = 2 x 4
+ float32x4_t acc_f32[col_groups];
+
+ const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+ for (int i = 0; i < col_groups; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+ float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
+ float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
+ float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
+ float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
+ float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
+ float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
+ float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
+ float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d);
+ float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d);
+
+ // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
+ int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+ int32x4_t acc_lo[col_groups];
+ int32x4_t acc_hi[col_groups];
+
+            // bsums holds 16 partial sums (one per 16 values); pairwise add leaves the 8 per-32-value sums of the whole superblock
+ const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+ int16_t bsums_arr[8];
+ vst1q_s16(bsums_arr, bsums);
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ for (int i = 0; i < col_groups; i++) {
+ acc_lo[i] = vdupq_n_s32(0);
+ acc_hi[i] = vdupq_n_s32(0);
+ }
+ // Need scales for the low and high nibbles
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+ int16x8_t q4sb_mins[2];
+ int16x8_t q4sb_scales[2];
+ for (int i = 0; i < 2; i++) {
+ int8_t aux_q4sb[8];
+ const int offset = sb * 24 + i * 12;
+ decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+ q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+ }
+
+ int8x16_t q8_qs[64 / 16];
+ for (int i = 0; i < 64 / 16; i++) {
+ q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
+ }
+
+ for (int c = 0; c < col_groups; c++) {
+ uint8x16_t q4_cols[8];
+ for (int i = 0; i < 8; i++) {
+ q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
+ }
+
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
+ acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
+
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
+ acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
+ }
+
+ // Scales
+ // row c0123 blk0 and blk1
+ const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+ const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+ const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
+ vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
+ acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
+ // row c4567 blk0 and blk1
+ const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+ const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+ const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
+ vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
+ acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
+
+ // Bias Correction
+ const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+ const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+ bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+ bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+ bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+ bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+ } // for sb
+
+ acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
+ acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
+ } // for b
+
+ int base = x * ncols_interleaved;
+ vst1q_f32(s + base, acc_f32[0]);
+ vst1q_f32(s + base + 4, acc_f32[1]);
+ } // for x
+ return;
+#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
void ggml_gemv_q4_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
@@ -518,7 +652,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
UNUSED(ncols_interleaved);
UNUSED(blocklen);
 
-#if defined(__aarch64__) && defined(__ARM_NEON)
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_pairs = ncols_interleaved / 2;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
 
@@ -615,7 +749,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
 
// 0123 or 4567
- // TODO: Single superblock mul at the end of the superblock
float32x4_t sumf_0 =
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
@@ -649,7 +782,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
-#endif // defined(__aarch64__) && defined(__ARM_NEON)
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
 
@@ -2069,6 +2202,206 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
 
+void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+ constexpr int qk = QK_K;
+ const int nb = n / qk;
+
+ constexpr int ncols_interleaved = 8;
+ constexpr int blocklen = 4;
+
+ assert(n % qk == 0);
+ assert(nr % 4 == 0);
+ assert(nc % ncols_interleaved == 0);
+
+ UNUSED(nb);
+ UNUSED(ncols_interleaved);
+ UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ constexpr int q8_k_blocklen = 4;
+    constexpr int acc_size = 2 * 4; // 4 rows × 2 column groups (c0123, c4567)
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 8 accumulators: 4 rows × 2 column groups of 4 columns
+ float32x4_t acc_f32[acc_size];
+
+ for (int y = 0; y < nr / q8_k_blocklen; y++) {
+ const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+ for (int x = 0; x < nc / ncols_interleaved; x++) {
+ const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+ for (int i = 0; i < acc_size; i++) {
+ acc_f32[i] = vdupq_n_f32(0);
+ }
+
+ for (int b = 0; b < nb; b++) {
+ // d4 0 1 2 3, 4 5 6 7
+ float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
+ float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
+ // d8 0 1 2 3
+ float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
+ // mins
+ float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
+ float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
+
+ // Precomputation of scales and mins
+ float32x4_t sbd_scale_0123[q8_k_blocklen];
+ float32x4_t sbd_scale_4567[q8_k_blocklen];
+ float32x4_t sbd_min_0123[q8_k_blocklen];
+ float32x4_t sbd_min_4567[q8_k_blocklen];
+
+ sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
+ sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
+ sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
+ sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
+
+ sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
+ sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
+ sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
+ sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
+
+ sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
+ sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
+ sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
+ sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
+
+ sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
+ sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
+ sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
+ sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
+
+ // Precomputation of bsums, each vpaddq calcs all the bsums for each row
+ const int16x8_t bsums[q8_k_blocklen] = {
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+ vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+ };
+ int16_t bsums_arr[QK_K / 64][8];
+ for (int q8_row = 0; q8_row < 4; q8_row++) {
+ vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+ }
+
+                // interleaved bias_acc: [0]->r0 c0123, [1]->r0 c4567, [2]->r1 c0123, [3]->r1 c4567, ..
+ int32x4_t bias_acc[acc_size];
+ for (int i = 0; i < acc_size; i++) {
+ bias_acc[i] = vdupq_n_s32(0);
+ }
+
+ for (int sb = 0; sb < QK_K / 64; sb++) {
+ // Int accumulators for qs vecdot (4 row x 2 col quartets)
+ int32x4_t acc_lo[acc_size];
+ int32x4_t acc_hi[acc_size];
+ for (int i = 0; i < acc_size; i++) {
+ acc_lo[i] = vdupq_n_s32(0);
+ acc_hi[i] = vdupq_n_s32(0);
+ }
+ // Need scales for the low and high nibbles
+ // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total
+ int16x8_t q4sb_scales[2];
+ int16x8_t q4sb_mins[2];
+ for (int i = 0; i < 2; i++) {
+ int8_t aux_q4sb[8];
+ const int offset = sb * 24 + i * 12;
+ decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+ q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+ }
+
+ constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
+ for (int k = 0; k < reads_per_sb; k++) {
+ const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
+ const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
+
+ // 0..3 & 32..35
+ const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
+ const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
+
+ const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
+ const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
+
+ acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
+ acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
+ acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
+ acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
+
+ acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
+ acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
+ acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
+ acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
+
+ const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
+ const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
+
+ acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
+ acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
+ acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
+ acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
+
+ acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
+ acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
+ acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
+ acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
+ }
+
+ // Scale and bias application
+ // acc is stored interleaved to match output layout
+ const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
+ const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
+ const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
+ const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
+ for (int row = 0; row < q8_k_blocklen; row++) {
+                        // Scales
+ // row c0123 blk0 and blk1
+ const float32x4_t sumf_0123 =
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
+ vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
+ acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
+
+ // row c4567 blk0 and blk1
+ const float32x4_t sumf_4567 =
+ vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
+ vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
+ acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
+
+ // Bias
+ const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
+ const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
+
+ // row c0123 blk0 and blk1
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+ bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+ // row c4567 blk0 and blk1
+ bias_acc[2 * row + 1] =
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+ bias_acc[2 * row + 1] =
+ vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+ }
+ } // for sb
+
+ for (int row = 0; row < q8_k_blocklen; row++) {
+ acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
+ acc_f32[2 * row + 1] =
+ vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
+ }
+ } // for b
+
+ for (int i = 0; i < q8_k_blocklen; i++) {
+ int row = y * q8_k_blocklen + i;
+ for (int j = 0; j < 2; j++) {
+ int col = x * ncols_interleaved + j * 4;
+ int offset = row * bs + col;
+ vst1q_f32(s + offset, acc_f32[2 * i + j]);
+ }
+ }
+ } // for x
+ } // for y
+ return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+ ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
void ggml_gemm_q4_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,