|
@@ -607,10 +607,11 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
|
|
|
assert(k % QK == 0);
|
|
assert(k % QK == 0);
|
|
|
|
|
|
|
|
const int nb = k / QK;
|
|
const int nb = k / QK;
|
|
|
|
|
+ const size_t bs = 2*sizeof(float) + QK/2;
|
|
|
|
|
|
|
|
- float * restrict pm = (float *) (y);
|
|
|
|
|
- float * restrict pd = (float *) (pm + nb);
|
|
|
|
|
- uint8_t * restrict pb = (uint8_t *) (pd + nb);
|
|
|
|
|
|
|
+ uint8_t * restrict pd = ((uint8_t *)y + 0*bs);
|
|
|
|
|
+ uint8_t * restrict pm = ((uint8_t *)y + 0*bs + sizeof(float));
|
|
|
|
|
+ uint8_t * restrict pb = ((uint8_t *)y + 0*bs + 2*sizeof(float));
|
|
|
|
|
|
|
|
uint8_t pp[QK/2];
|
|
uint8_t pp[QK/2];
|
|
|
|
|
|
|
@@ -627,8 +628,10 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
|
|
|
const float d = (max - min) / ((1 << 4) - 1);
|
|
const float d = (max - min) / ((1 << 4) - 1);
|
|
|
const float id = d ? 1.0f/d : 0.0f;
|
|
const float id = d ? 1.0f/d : 0.0f;
|
|
|
|
|
|
|
|
- pm[i] = min;
|
|
|
|
|
- pd[i] = d;
|
|
|
|
|
|
|
+ *(float *)pm = min;
|
|
|
|
|
+ *(float *)pd = d;
|
|
|
|
|
+ pm += bs;
|
|
|
|
|
+ pd += bs;
|
|
|
|
|
|
|
|
for (int l = 0; l < QK; l += 2) {
|
|
for (int l = 0; l < QK; l += 2) {
|
|
|
const float v0 = (x[i*QK + l + 0] - min)*id;
|
|
const float v0 = (x[i*QK + l + 0] - min)*id;
|
|
@@ -643,7 +646,8 @@ void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
|
|
|
pp[l/2] = vi0 | (vi1 << 4);
|
|
pp[l/2] = vi0 | (vi1 << 4);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- memcpy(pb + i*QK/2, pp, sizeof(pp));
|
|
|
|
|
|
|
+ memcpy(pb, pp, sizeof(pp));
|
|
|
|
|
+ pb += bs;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -687,16 +691,17 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
|
|
|
assert(k % QK == 0);
|
|
assert(k % QK == 0);
|
|
|
|
|
|
|
|
const int nb = k / QK;
|
|
const int nb = k / QK;
|
|
|
|
|
+ const size_t bs = 2*sizeof(float) + QK/2;
|
|
|
|
|
|
|
|
- const float * restrict pm = (const float *) (x);
|
|
|
|
|
- const float * restrict pd = (const float *) (pm + nb);
|
|
|
|
|
- const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
|
|
|
|
|
|
|
+ const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
|
|
|
|
|
+ const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
|
|
|
|
|
+ const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
|
|
|
|
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
for (int i = 0; i < nb; i++) {
|
|
|
- const float m = pm[i];
|
|
|
|
|
- const float d = pd[i];
|
|
|
|
|
|
|
+ const float d = *(const float *) (pd + i*bs);
|
|
|
|
|
+ const float m = *(const float *) (pm + i*bs);
|
|
|
|
|
|
|
|
- const uint8_t * restrict pp = pb + i*QK/2;
|
|
|
|
|
|
|
+ const uint8_t * restrict pp = pb + i*bs;
|
|
|
|
|
|
|
|
for (int l = 0; l < QK; l += 2) {
|
|
for (int l = 0; l < QK; l += 2) {
|
|
|
const uint8_t vi = pp[l/2];
|
|
const uint8_t vi = pp[l/2];
|
|
@@ -1584,28 +1589,109 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
|
|
|
inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
|
|
inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict x, const void * restrict y) {
|
|
|
const int nb = n / QK;
|
|
const int nb = n / QK;
|
|
|
|
|
|
|
|
- const float * restrict pm0 = (const float *) x;
|
|
|
|
|
- const float * restrict pm1 = (const float *) y;
|
|
|
|
|
|
|
+ const size_t bs = 2*sizeof(float) + QK/2;
|
|
|
|
|
|
|
|
- const float * restrict pd0 = (const float *) (pm0 + nb);
|
|
|
|
|
- const float * restrict pd1 = (const float *) (pm1 + nb);
|
|
|
|
|
|
|
+ const uint8_t * restrict pd0 = ((const uint8_t *)x + 0*bs);
|
|
|
|
|
+ const uint8_t * restrict pd1 = ((const uint8_t *)y + 0*bs);
|
|
|
|
|
+
|
|
|
|
|
+ const uint8_t * restrict pm0 = ((const uint8_t *)x + 0*bs + sizeof(float));
|
|
|
|
|
+ const uint8_t * restrict pm1 = ((const uint8_t *)y + 0*bs + sizeof(float));
|
|
|
|
|
|
|
|
- const uint8_t * restrict pb0 = (const uint8_t *) (pd0 + nb);
|
|
|
|
|
- const uint8_t * restrict pb1 = (const uint8_t *) (pd1 + nb);
|
|
|
|
|
|
|
+ const uint8_t * restrict pb0 = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
|
|
|
|
|
+ const uint8_t * restrict pb1 = ((const uint8_t *)y + 0*bs + 2*sizeof(float));
|
|
|
|
|
|
|
|
float sumf = 0.0;
|
|
float sumf = 0.0;
|
|
|
|
|
|
|
|
-#if 1
|
|
|
|
|
|
|
+#if defined(__AVX2__)
|
|
|
|
|
+#if QK == 32
|
|
|
|
|
+ // Initialize accumulator with zeros
|
|
|
|
|
+ __m256 acc = _mm256_setzero_ps();
|
|
|
|
|
+ // Accumulator for constant offsets
|
|
|
|
|
+ float acc_offset = 0.0f;
|
|
|
|
|
+
|
|
|
|
|
+ // Main loop
|
|
|
|
|
+ for (int i = 0; i < nb; ++i) {
|
|
|
|
|
+ const float * m0 = (const float *) (pm0 + i*bs);
|
|
|
|
|
+ const float * m1 = (const float *) (pm1 + i*bs);
|
|
|
|
|
+
|
|
|
|
|
+ const float * d0 = (const float *) (pd0 + i*bs);
|
|
|
|
|
+ const float * d1 = (const float *) (pd1 + i*bs);
|
|
|
|
|
+
|
|
|
|
|
+ const uint8_t * restrict p0 = pb0 + i*bs;
|
|
|
|
|
+ const uint8_t * restrict p1 = pb1 + i*bs;
|
|
|
|
|
+
|
|
|
|
|
+ const __m256 d0v = _mm256_broadcast_ss( d0 );
|
|
|
|
|
+ const __m256 d1v = _mm256_broadcast_ss( d1 );
|
|
|
|
|
+ const __m256 m0v = _mm256_broadcast_ss( m0 );
|
|
|
|
|
+ const __m256 m1v = _mm256_broadcast_ss( m1 );
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ // Compute combined scale for the block
|
|
|
|
|
+ const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
|
|
|
|
|
+
|
|
|
|
|
+ // Compute cross scales for the block
|
|
|
|
|
+ const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
|
|
|
|
|
+ const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
|
|
|
|
|
+ const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0b10101010 );
|
|
|
|
|
+
|
|
|
|
|
+ // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
|
|
|
|
|
+ __m256i bx = bytesFromNibbles( p0 );
|
|
|
|
|
+ __m256i by = bytesFromNibbles( p1 );
|
|
|
|
|
+
|
|
|
|
|
+ // Now we have a vector with bytes in [ 0 .. 15 ] interval.
|
|
|
|
|
+
|
|
|
|
|
+ // Sign-extend first 16 signed bytes into int16_t
|
|
|
|
|
+ __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
|
|
|
|
|
+ __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
|
|
|
|
|
+ // Compute products of int16_t integers, add pairwise
|
|
|
|
|
+ __m256i i32 = _mm256_madd_epi16( x16, y16 );
|
|
|
|
|
+
|
|
|
|
|
+ // Sign-extend last 16 signed bytes into int16_t vectors
|
|
|
|
|
+ __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
|
|
|
|
|
+ __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
|
|
|
|
|
+ // Accumulate products of int16_t integers
|
|
|
|
|
+ i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
|
|
|
|
|
+
|
|
|
|
|
+ // compute sums of unsigned bytes in bx, by in blocks of 8.
|
|
|
|
|
+ // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
|
|
|
|
|
+ // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
|
|
|
|
|
+ // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
|
|
|
|
|
+ __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
|
|
|
|
|
+ __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
|
|
|
|
|
+ __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
|
|
|
|
|
+ __m256 sums = _mm256_cvtepi32_ps( sumsi );
|
|
|
|
|
+
|
|
|
|
|
+ // Convert int32_t to float
|
|
|
|
|
+ __m256 p = _mm256_cvtepi32_ps( i32 );
|
|
|
|
|
+ // Apply the scale, and accumulate
|
|
|
|
|
+ // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
|
|
|
|
|
+ acc = _mm256_fmadd_ps( scale_01, p, acc );
|
|
|
|
|
+ acc = _mm256_fmadd_ps( cross_scales, sums, acc );
|
|
|
|
|
+ // acc_offset += m0*m1 (for each entry in the block)
|
|
|
|
|
+ acc_offset += (*m0)*(*m1);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Return horizontal sum of the acc vector
|
|
|
|
|
+ __m128 res = _mm256_extractf128_ps( acc, 1 );
|
|
|
|
|
+ res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
|
|
|
|
|
+ res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
|
|
|
|
+ res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
|
|
|
|
+
|
|
|
|
|
+ sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
|
|
|
|
|
+#else
|
|
|
|
|
+#error "not implemented for QK"
|
|
|
|
|
+#endif
|
|
|
|
|
+#else
|
|
|
// scalar
|
|
// scalar
|
|
|
for (int i = 0; i < nb; i++) {
|
|
for (int i = 0; i < nb; i++) {
|
|
|
- const float m0 = pm0[i];
|
|
|
|
|
- const float m1 = pm1[i];
|
|
|
|
|
|
|
+ const float m0 = *(const float *) (pm0 + i*bs);
|
|
|
|
|
+ const float m1 = *(const float *) (pm1 + i*bs);
|
|
|
|
|
|
|
|
- const float d0 = pd0[i];
|
|
|
|
|
- const float d1 = pd1[i];
|
|
|
|
|
|
|
+ const float d0 = *(const float *) (pd0 + i*bs);
|
|
|
|
|
+ const float d1 = *(const float *) (pd1 + i*bs);
|
|
|
|
|
|
|
|
- const uint8_t * restrict p0 = pb0 + i*QK/2;
|
|
|
|
|
- const uint8_t * restrict p1 = pb1 + i*QK/2;
|
|
|
|
|
|
|
+ const uint8_t * restrict p0 = pb0 + i*bs;
|
|
|
|
|
+ const uint8_t * restrict p1 = pb1 + i*bs;
|
|
|
|
|
|
|
|
for (int j = 0; j < QK/2; j++) {
|
|
for (int j = 0; j < QK/2; j++) {
|
|
|
const uint8_t v0 = p0[j];
|
|
const uint8_t v0 = p0[j];
|
|
@@ -1839,16 +1925,17 @@ inline static void ggml_vec_mad_q4_1(const int n, float * restrict y, void * res
|
|
|
assert(n % QK == 0);
|
|
assert(n % QK == 0);
|
|
|
|
|
|
|
|
const int nb = n / QK;
|
|
const int nb = n / QK;
|
|
|
|
|
+ const size_t bs = 2*sizeof(float) + QK/2;
|
|
|
|
|
|
|
|
- const float * restrict pm = (const float *) (x);
|
|
|
|
|
- const float * restrict pd = (const float *) (pm + nb);
|
|
|
|
|
- const uint8_t * restrict pb = (const uint8_t *) (pd + nb);
|
|
|
|
|
|
|
+ const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
|
|
|
|
|
+ const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
|
|
|
|
|
+ const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
|
|
|
|
|
|
|
|
for (int i = 0; i < nb; i++) {
|
|
for (int i = 0; i < nb; i++) {
|
|
|
- const float m = pm[i];
|
|
|
|
|
- const float d = pd[i];
|
|
|
|
|
|
|
+ const float d = *(const float *) (pd + i*bs);
|
|
|
|
|
+ const float m = *(const float *) (pm + i*bs);
|
|
|
|
|
|
|
|
- const uint8_t * restrict pp = pb + i*QK/2;
|
|
|
|
|
|
|
+ const uint8_t * restrict pp = pb + i*bs;
|
|
|
|
|
|
|
|
for (int l = 0; l < QK; l += 2) {
|
|
for (int l = 0; l < QK; l += 2) {
|
|
|
const uint8_t vi = pp[l/2];
|
|
const uint8_t vi = pp[l/2];
|