|
|
@@ -103,10 +103,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
|
|
}
|
|
|
|
|
|
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
|
|
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
|
|
+#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
|
|
const __m256i zero = _mm256_setzero_si256();
|
|
|
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
|
|
|
return _mm256_cvtepi32_ps(summed_pairs);
|
|
|
+#elif defined(__AVXVNNI__)
|
|
|
+ const __m256i zero = _mm256_setzero_si256();
|
|
|
+ const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
|
|
|
+ return _mm256_cvtepi32_ps(summed_pairs);
|
|
|
#else
|
|
|
// Perform multiplication and create 16-bit values
|
|
|
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|