@@ -24,6 +24,29 @@

 #define UNUSED GGML_UNUSED

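+// Helper for the Q4_K x8 kernels below: each 12-byte group of the interleaved
+// scales[] array packs eight 6-bit sub-block scales and eight 6-bit mins
+// (same packing as plain Q4_K, one entry per interleaved column). The masks
+// below unpack them into eight 16-bit mins and eight 8-bit scales.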
+static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
+                                             int16x8_t * out_mins,
+                                             int8_t * out_scales) {
+    constexpr uint32_t kmask1 = 0x3f3f3f3f;
+    constexpr uint32_t kmask2 = 0x0f0f0f0f;
+    constexpr uint32_t kmask3 = 0x03030303;
+    constexpr uint8_t  scales_size = 12;
+
+    uint32_t sm[3];
+    memcpy(sm, scales_in, scales_size);
+
+    const uint32_t mins_0_3 = sm[1] & kmask1;
+    const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4);
+    const uint32x2_t mins_u32 = { mins_0_3, mins_4_7 };
+
+    *out_mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins_u32)));
+
+    uint32_t scales_u32[2];
+    scales_u32[0] = sm[0] & kmask1;
+    scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
+    memcpy(out_scales, scales_u32, 8);
+}
+
 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
@@ -474,6 +497,162 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
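+// GEMV: one row of Q8_K activations against 8 interleaved Q4_K columns
+// (block_q4_Kx8), writing 8 floats of s per column group. For every 256-value
+// super-block the low/high nibbles are dotted against the q8 quants with
+// ggml_vdotq_s32, scaled by the decoded sub-block scales, and the mins
+// contribution (bsums * mins * dmin) is subtracted at the end, since Q4_K
+// stores quants as unsigned offsets from a per-sub-block minimum.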
+void ggml_gemv_q4_K_8x8_q8_K(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON)
+    constexpr int col_pairs = ncols_interleaved / 2;
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 1x8 result tile, kept as 2 float32x4 accumulators (cols 0-3 and 4-7)
+    float32x4_t acc_f32[ncols_interleaved / 4];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+        for (int i = 0; i < ncols_interleaved / 4; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
+            float32x4_t sb_scale_0 = vmulq_f32(q4_d_0, q8_d);
+            float32x4_t sb_scale_1 = vmulq_f32(q4_d_1, q8_d);
+            float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));      // dmin 0..3
+            float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));  // dmin 4..7
+            float32x4_t sb_min_0 = vmulq_f32(q4_dmin_0, q8_d);
+            float32x4_t sb_min_1 = vmulq_f32(q4_dmin_1, q8_d);
+
+            // interleaved bias_acc: [0] -> row 0, cols 0-3; [1] -> row 0, cols 4-7
+            int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
+            // 2 sub-blocks (low and high nibbles) are processed per sb iteration
+            int32x4_t acc_lo[col_pairs];
+            int32x4_t acc_hi[col_pairs];
+
+            // bsums holds 16 partial sums of 16 quants each; a pairwise add leaves
+            // the 8 per-32-quant sums of the entire block
+            const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
+            int16_t bsums_arr[8];
+            vst1q_s16(bsums_arr, bsums);
+            for (int sb = 0; sb < QK_K / 64; sb++) {
+                for (int i = 0; i < col_pairs; i++) {
+                    acc_lo[i] = vdupq_n_s32(0);
+                    acc_hi[i] = vdupq_n_s32(0);
+                }
+                // Need scales for the low and high nibbles:
+                // 2 * 12 = 24 bytes per sb iteration, 4 iterations -> 4 * 24 = 96 bytes total
+                int16x8_t q4sb_mins[2];  // widened to int16 because bias_acc needs it later
+                int16x8_t q4sb_scales[2];
+                for (int i = 0; i < 2; i++) {
+                    int8_t aux_q4sb[8];
+                    const int offset = sb * 24 + i * 12;
+                    decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
+                    q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
+                }
+
+                const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K;
+
+                // Load the 64 q8 quants for this sb iteration, each 8-byte group duplicated
+                // into both halves of a register so one vdot covers the pair of interleaved
+                // q4 columns; the same q4 bytes later supply both the low and high nibbles
+                const int8_t * q8_base = q8_ptr[b].qs + sb * 64;
+                int8x16_t q8_qs[8];
+                for (int i = 0; i < 8; i++) {
+                    q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8));
+                }
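+                // q8_qs[0..3] pair with the low nibbles (quants 0..31) of the q4 bytes
+                // loaded below, q8_qs[4..7] with the high nibbles (quants 32..63)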
+
+                // Q4 columns are iterated in pairs (01, 23, 45, 67)
+                for (int cp = 0; cp < col_pairs; cp++) {
+                    uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_base + 16 * cp);
+                    uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_base + 16 * cp + 64);
+                    uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_base + 16 * cp + 128);
+                    uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_base + 16 * cp + 192);
+
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), q8_qs[0]);  // 0 .. 7
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), q8_qs[1]);  // 8 ..15
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), q8_qs[2]);  // 16..23
+                    acc_lo[cp] =
+                        ggml_vdotq_s32(acc_lo[cp], vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), q8_qs[3]);  // 24..31
+
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), q8_qs[4]);  // 32..39
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), q8_qs[5]);  // 40..47
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), q8_qs[6]);  // 48..55
+                    acc_hi[cp] =
+                        ggml_vdotq_s32(acc_hi[cp], vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), q8_qs[7]);  // 56..63
+                }
+
+                // Iterate over pairs of column pairs (4 columns) so a single 128-bit
+                // register can be used: p = 0 -> cols 0-3, p = 2 -> cols 4-7
+                for (int i = 0, p = 0; p < col_pairs; i++, p += 2) {
+                    int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q4sb_scales[0]) : vget_high_s16(q4sb_scales[0]);
+                    int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q4sb_scales[1]) : vget_high_s16(q4sb_scales[1]);
+                    float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
+
+                    // cols 0-3 or 4-7
+                    // TODO: single super-block multiply at the end of the super-block
+                    float32x4_t sumf_0 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
+
+                    float32x4_t sumf_1 =
+                        vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1])));
+                    acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1);
+                }
+
+                // Accumulate the bias term (bsums * mins). Each 32-quant sub-block spans a
+                // pair of q8 bsums, already combined by the pairwise add above; broadcast
+                // the scalar with vdup_n_s16.
+                int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
+                int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
+
+                // cols 0-3 bias
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+
+                // cols 4-7 bias
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+            } // for sb
+
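+            // Q4_K quants are unsigned offsets from a per-sub-block min, so the
+            // accumulated bsums * mins term, scaled by dmin, is subtracted once per block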
+            acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0);
+            acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_1);
+        } // for b
+
+        int base = x * ncols_interleaved;
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    } // for x
+    return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -1889,3 +2068,212 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
 #endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
     ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
+
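+// GEMM: a 4-row block of Q8_K activations (block_q8_Kx4) against 8 interleaved
+// Q4_K columns (block_q4_Kx8), producing a 4x8 tile of s. Built around i8mm:
+// each vmmlaq_s32 multiplies 2 q8 rows by 2 q4 columns at once, and the
+// resulting 2x2 tiles are reordered into rows before the final scaling.
+// The mins bias is accumulated from the bsums and subtracted at the end,
+// as in the GEMV kernel above.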
+void ggml_gemm_q4_K_8x8_q8_K(int n,
+                             float * GGML_RESTRICT s,
+                             size_t bs,
+                             const void * GGML_RESTRICT vx,
+                             const void * GGML_RESTRICT vy,
+                             int nr,
+                             int nc) {
+    constexpr int qk = QK_K;
+    const int nb = n / qk;
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    constexpr int q8_k_blocklen = 4;
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+
+    // 8 accumulators: 2 row pairs x 4 col pairs
+    float32x4_t acc_f32[blocklen];
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
+
+            for (int i = 0; i < blocklen; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                // Adjacent pairs of bsums belong to the same 32-quant sub-block,
+                // so a pairwise add yields the per-sub-block sums
+                const int16x8_t bsums[4]{
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
+                    vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
+                };
+                int16_t bsums_arr[4][8];
+                for (int q8_row = 0; q8_row < 4; q8_row++) {
+                    vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
+                }
+
+                int32x4_t sb_acc[4];    // Aux accumulators to store sub-block (partial) results
+                int32x4_t acc[8];       // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7]
+                int32x4_t bias_acc[8];  // interleaved bias_acc: [0] -> r0 cols 0-3, [1] -> r0 cols 4-7, [2] -> r1 cols 0-3, ...
+                for (int i = 0; i < 8; i++) {
+                    acc[i] = vdupq_n_s32(0);
+                    bias_acc[i] = vdupq_n_s32(0);
+                }
+
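+                // Each sb iteration covers 64 quant positions: the low- and
+                // high-nibble sub-blocks extracted from the same q4 bytes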
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    // Need scales for the low and high nibbles:
+                    // 2 * 12 = 24 bytes per sb iteration, 4 iterations -> 4 * 24 = 96 bytes total
+                    int8_t q4sb_scales[2][8];
+                    int16x8_t q4sb_mins[2];  // widened to int16 because bias_acc needs it later
+                    for (int i = 0; i < 2; i++) {
+                        const int offset = sb * 24 + i * 12;
+                        decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]);
+                    }
+
+                    // q8_ptr[b].qs holds the Q8 rows interleaved in pairs (01, 23)
+                    const int8_t * q8_base = q8_ptr[b].qs + sb * 256;
+
+                    int8x16_t q8_qs_01[8];
+                    int8x16_t q8_qs_23[8];
+
+                    // Load 32 bytes per row pair, one sub-block at a time
+                    for (int i = 0; i < 8; i++) {
+                        const int offset = i * 32;  // 16 for row pair 01, 16 for row pair 23
+                        q8_qs_01[i] = vld1q_s8(q8_base + offset);
+                        q8_qs_23[i] = vld1q_s8(q8_base + offset + 16);
+                    }
+
+                    const int8x16_t q8s[2][8] = {
+                        { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3],
+                          q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] },
+                        { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3],
+                          q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] },
+                    };
+
+                    // Q4 columns are iterated in pairs (01, 23, 45, 67)
+                    for (int cp = 0; cp < ncols_interleaved / 2; cp++) {
+                        for (int i = 0; i < 4; i++) {
+                            sb_acc[i] = vdupq_n_s32(0);
+                        }
+
+                        uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0);    // 0 .. 7 & 32..39
+                        uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64);   // 8 ..15 & 40..47
+                        uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128);  // 16..23 & 48..55
+                        uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192);  // 24..31 & 56..63
+                        const int8x16_t q4_nibbles[2][4] = {
+                            {
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)),
+                                vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)),
+                            },
+                            {
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)),
+                                vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)),
+                            }
+                        };
+
+                        // Compute the quant mul-adds for each q8 row pair (rp = 0 -> rows 01,
+                        // rp = 1 -> rows 23) and each 32-quant sub-block (blk = low/high nibbles)
+                        for (int rp = 0; rp < 2; rp++) {
+                            for (int blk = 0; blk < 2; blk++) {
+                                const int8x16_t * q8 = &q8s[rp][4 * blk];
+                                const int8x16_t * q4 = q4_nibbles[blk];
+                                int32x4_t acc = sb_acc[2 * rp + blk];
+                                // mul-add for each group of quants in the same sub-block
+                                for (int qs_offset = 0; qs_offset < 4; qs_offset++) {
+                                    acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]);
+                                }
+                                sb_acc[2 * rp + blk] = acc;
+                            }
+                        }
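+                        // Each sb_acc register now holds a 2x2 i8mm tile
+                        // { c0*r0, c0*r1, c1*r0, c1*r1 } for this column pair and row pair;
+                        // the zip/reorder after the sb loop puts it into row-major order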
+
+                        // Apply the sub-block scales; scales[i] corresponds to column i
+                        const int scale_offset = cp * 2;
+                        for (int blk = 0; blk < 2; blk++) {
+                            const int32x4_t block_scale = {
+                                (int32_t) q4sb_scales[blk][scale_offset],
+                                (int32_t) q4sb_scales[blk][scale_offset],
+                                (int32_t) q4sb_scales[blk][scale_offset + 1],
+                                (int32_t) q4sb_scales[blk][scale_offset + 1],
+                            };
+                            acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale);
+                            acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale);
+                        }
+                    }
+
+                    // Accumulate the bias term (bsums * mins)
+                    for (int q8_row = 0; q8_row < 4; q8_row++) {
+                        // Each 32-quant sub-block spans a pair of q8 bsums, already combined
+                        // by the pairwise add above; broadcast the scalar with vdup_n_s16
+                        int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]);
+                        int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]);
+
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
+                        bias_acc[2 * q8_row] =
+                            vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
+                        bias_acc[2 * q8_row + 1] =
+                            vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
+                    }
+                } // for sb
+
+                // Reorder the i8mm output to match the bias and output layout:
+                // zip each 2x2 tile into row-major order, then gather the four
+                // column-pair results of one row into a single register
+                for (int i = 0; i < 8; i++) {
+                    int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i]));
+                    acc[i] = vcombine_s32(aux.val[0], aux.val[1]);
+                }
+                int32x4_t reorder_acc[8] = {
+                    vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])),
+                    vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])),
+                    vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])),
+                    vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])),
+                    vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])),
+                    vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])),
+                    vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])),
+                    vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])),
+                };
+
+                for (int i = 0; i < q8_k_blocklen; i++) {
+                    for (int j = 0; j < 2; j++) {
+                        float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]);
+                        float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4)));
+                        const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d);
+
+                        float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4)));
+                        const float32x4_t scale = vmulq_f32(q4_d, q8_d);
+
+                        acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins);
+                        acc_f32[2 * i + j] =
+                            vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale);
+                    }
+                }
+            } // for b
+
+            // With the previous reorder, the tile is already in the correct memory layout.
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {
+                    int col = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        } // for x
+    } // for y
+    return;
+#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
+    ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
+}