Optimize AVX2 ggml_vec_dot_q4_0 (#642)

slaren, 2 years ago
Commit: 1d08882afa

1 changed file with 18 additions and 13 deletions

ggml.c: +18 −13

@@ -1833,7 +1833,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     const block_q4_0 * restrict x = vx;
     const block_q4_0 * restrict y = vy;

-    ggml_float sumf = 0.0;
+    float sumf = 0.0;

 #if defined(__ARM_NEON)
     float sum0 = 0.0f;
@@ -1928,7 +1928,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
 #endif
     }

-    sumf = (ggml_float)(sum0 + sum1);
+    sumf = sum0 + sum1;
 #elif defined(__AVX512F__)
     // Initialize accumulator with zeros
     __m512 acc0 = _mm512_setzero_ps();
@@ -1962,6 +1962,10 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
     __m256 acc = _mm256_setzero_ps();

     // Main loop
+    // TODO: figure a way to do this in a portable way
+    #ifdef __GNUC__
+    #pragma GCC unroll 16
+    #endif
     for (int i = 0; i < nb; ++i) {
         // Compute combined scale for the block
         const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
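
The unroll hint added in the hunk above is wrapped in #ifdef __GNUC__, and the TODO asks for a portable way to express it. As a sketch only (not part of this commit), one option is to branch on the compiler and let anything else fall back to the plain loop, using Clang's documented loop pragma for the Clang case:

// Sketch, not from the patch: compilers without a known unroll hint
// simply compile the loop as written.
#if defined(__clang__)
#pragma clang loop unroll_count(16)
#elif defined(__GNUC__)
#pragma GCC unroll 16
#endif
for (int i = 0; i < nb; ++i) {
    // ... loop body unchanged ...
}
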
@@ -1975,20 +1979,21 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
         bx = _mm256_sub_epi8( bx, off );
         by = _mm256_sub_epi8( by, off );

-        // Sign-extend first 16 signed bytes into int16_t
-        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
-        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        // Compute products of int16_t integers, add pairwise
-        __m256i i32 = _mm256_madd_epi16( x16, y16 );
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8(bx, bx);

-        // Sign-extend last 16 signed bytes into int16_t vectors
-        x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
-        y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        // Accumulate products of int16_t integers
-        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8(by, bx);
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+
+        const __m256i ones = _mm256_set1_epi16(1);
+        const __m256i i32 = _mm256_madd_epi16(ones, dot);

         // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( i32 );
+        const __m256 p = _mm256_cvtepi32_ps( i32 );
+
         // Apply the scale, and accumulate
         acc = _mm256_fmadd_ps( d, p, acc );
     }
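
For reference, the substance of the AVX2 change: the old path sign-extended each half of the 8-bit vectors to 16 bits and multiplied with _mm256_madd_epi16, while the new path multiplies the bytes directly with _mm256_maddubs_epi16. That instruction treats its first operand as unsigned and its second as signed, so the patch first takes the absolute value of x (_mm256_sign_epi8(bx, bx)) and moves x's sign onto y (_mm256_sign_epi8(by, bx)); the per-byte products are unchanged, and the final _mm256_madd_epi16 against a vector of ones adds adjacent 16-bit products into 32-bit lanes. Below is a minimal scalar sketch of that identity; it is not part of the commit and assumes q4_0 values in [-8, 7], where the 16-bit pairwise sums cannot saturate.

/* Scalar sketch (not from the patch) of the identity behind the new AVX2 path:
 * |x| * copysign(y, x) == x * y, computed lane by lane for one q4_0 block. */
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>

int main(void) {
    int8_t bx[32], by[32];
    for (int i = 0; i < 32; ++i) {
        bx[i] = (int8_t)(rand() % 16 - 8);   /* q4_0 nibbles after subtracting the offset of 8 */
        by[i] = (int8_t)(rand() % 16 - 8);
    }

    int32_t ref = 0, opt = 0;
    for (int i = 0; i < 32; ++i) {
        /* Old path: sign-extend both operands and multiply as signed integers. */
        ref += (int32_t)bx[i] * (int32_t)by[i];

        /* New path, one lane at a time: */
        uint8_t ax = (uint8_t)abs(bx[i]);                   /* _mm256_sign_epi8(bx, bx) */
        int8_t  sy = (int8_t)(bx[i] < 0 ? -by[i] : by[i]);  /* _mm256_sign_epi8(by, bx) */
        /* (sign_epi8 also zeroes the lane when bx is 0, but then ax is 0 too,
         *  so the product is unaffected.) */
        opt += (int32_t)ax * (int32_t)sy;                   /* maddubs, then madd with ones */
    }

    printf("reference = %d, optimized = %d\n", ref, opt);   /* the two sums agree */
    return 0;
}

The payoff visible in the diff is instruction count per block: two _mm256_sign_epi8, one _mm256_maddubs_epi16, and one _mm256_madd_epi16 replace four sign-extensions, two _mm256_madd_epi16, and an _mm256_add_epi32.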