|
@@ -361,7 +361,7 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
|
|
|
|
|
|
|
// AVX routines provided by GH user Const-me
|
|
// AVX routines provided by GH user Const-me
|
|
|
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
|
|
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
|
|
|
-#if __AVX2__
|
|
|
|
|
|
|
+#if __AVX2__ || __AVX512F__
|
|
|
// Unpack 32 4-bit fields into 32 bytes
|
|
// Unpack 32 4-bit fields into 32 bytes
|
|
|
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
|
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
|
|
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
|
|
static inline __m256i bytesFromNibbles( const uint8_t* rsi )
|
|
@@ -397,7 +397,6 @@ static inline __m128i packNibbles( __m256i bytes )
|
|
|
}
|
|
}
|
|
|
#endif
|
|
#endif
|
|
|
|
|
|
|
|
-
|
|
|
|
|
// method 5
|
|
// method 5
|
|
|
// blocks of QK elements
|
|
// blocks of QK elements
|
|
|
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
|
|
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
|
|
@@ -1262,6 +1261,47 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
|
|
*s = sumf;
|
|
*s = sumf;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+#if __AVX512F__ && QK == 32
|
|
|
|
|
+static inline __m512 dot_q4_0_oneblock_avx512(
|
|
|
|
|
+ __m512 acc,
|
|
|
|
|
+ const uint8_t * pd0,
|
|
|
|
|
+ const uint8_t * pd1,
|
|
|
|
|
+ const uint8_t * pb0,
|
|
|
|
|
+ const uint8_t * pb1,
|
|
|
|
|
+ size_t bs,
|
|
|
|
|
+ int i
|
|
|
|
|
+) {
|
|
|
|
|
+ const float * d0_0 = (const float *) (pd0 + i*bs);
|
|
|
|
|
+ const float * d1_0 = (const float *) (pd1 + i*bs);
|
|
|
|
|
+
|
|
|
|
|
+ const uint8_t * restrict p0 = pb0 + (i+0)*bs;
|
|
|
|
|
+ const uint8_t * restrict p1 = pb1 + (i+0)*bs;
|
|
|
|
|
+
|
|
|
|
|
+ // Compute combined scale for the block
|
|
|
|
|
+ float scaleScalar = d0_0[0] * d1_0[0];
|
|
|
|
|
+ __m512 scale = _mm512_set1_ps( scaleScalar );
|
|
|
|
|
+
|
|
|
|
|
+ __m256i bx = bytesFromNibbles( p0 );
|
|
|
|
|
+ __m256i by = bytesFromNibbles( p1 );
|
|
|
|
|
+
|
|
|
|
|
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
|
|
|
|
|
+ const __m256i off = _mm256_set1_epi8( 8 );
|
|
|
|
|
+ bx = _mm256_sub_epi8( bx, off );
|
|
|
|
|
+ by = _mm256_sub_epi8( by, off );
|
|
|
|
|
+
|
|
|
|
|
+ // Sign-extend 16 signed bytes into int16_t
|
|
|
|
|
+ __m512i x32 = _mm512_cvtepi8_epi16( bx );
|
|
|
|
|
+ __m512i y32 = _mm512_cvtepi8_epi16( by );
|
|
|
|
|
+ // Compute products of int16_t integers, add pairwise
|
|
|
|
|
+ __m512i i64 = _mm512_madd_epi16( x32, y32 );
|
|
|
|
|
+
|
|
|
|
|
+ // Convert int32_t to float
|
|
|
|
|
+ __m512 p = _mm512_cvtepi32_ps( i64 );
|
|
|
|
|
+ // Apply the scale, and accumulate
|
|
|
|
|
+ return _mm512_fmadd_ps( scale, p, acc );
|
|
|
|
|
+}
|
|
|
|
|
+#endif
|
|
|
|
|
+
|
|
|
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
|
inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
|
|
ggml_float sumf = 0.0;
|
|
ggml_float sumf = 0.0;
|
|
|
|
|
|
|
@@ -1417,6 +1457,40 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|
|
#else
|
|
#else
|
|
|
#error "not implemented for QK"
|
|
#error "not implemented for QK"
|
|
|
#endif
|
|
#endif
|
|
|
|
|
+#elif defined(__AVX512F__)
|
|
|
|
|
+
|
|
|
|
|
+#if QK == 32
|
|
|
|
|
+ // Initialize accumulator with zeros
|
|
|
|
|
+ __m512 acc0 = _mm512_setzero_ps();
|
|
|
|
|
+ __m512 acc1 = _mm512_setzero_ps();
|
|
|
|
|
+
|
|
|
|
|
+ const int superblock_size = 8;
|
|
|
|
|
+ const int superblock_count = nb / superblock_size;
|
|
|
|
|
+ const int remainder = nb % superblock_size;
|
|
|
|
|
+
|
|
|
|
|
+ for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
|
|
|
|
|
+ int i = superblock_ix * superblock_size;
|
|
|
|
|
+
|
|
|
|
|
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+0 );
|
|
|
|
|
+ acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+1 );
|
|
|
|
|
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+2 );
|
|
|
|
|
+ acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+3 );
|
|
|
|
|
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+4 );
|
|
|
|
|
+ acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+5 );
|
|
|
|
|
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i+6 );
|
|
|
|
|
+ acc1 = dot_q4_0_oneblock_avx512( acc1, pd0, pd1, pb0, pb1, bs, i+7 );
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Remainders
|
|
|
|
|
+ for (int i = superblock_count * superblock_size; i < nb; ++i) {
|
|
|
|
|
+ acc0 = dot_q4_0_oneblock_avx512( acc0, pd0, pd1, pb0, pb1, bs, i );
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Horizontal sum of all lanes of the accumulator
|
|
|
|
|
+ sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
|
|
|
|
|
+#else
|
|
|
|
|
+#error "not implemented for QK"
|
|
|
|
|
+#endif
|
|
|
#elif defined(__AVX2__)
|
|
#elif defined(__AVX2__)
|
|
|
#if QK == 32
|
|
#if QK == 32
|
|
|
const size_t countBlocks = nb;
|
|
const size_t countBlocks = nb;
|
|
@@ -1928,7 +2002,7 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
|
|
|
const size_t bs = 2*sizeof(float) + QK/2;
|
|
const size_t bs = 2*sizeof(float) + QK/2;
|
|
|
|
|
|
|
|
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
|
|
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
|
|
|
- const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
|
|
|
|
|
|
|
+ const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
|
|
|
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
|
|
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
|
|
|
|
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
for (int i = 0; i < nb; i++) {
|