summary | refs | log | tree | commit | diff
path: root/ggml/src/iqk/iqk_common.h
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/iqk/iqk_common.h')
-rw-r--r-- ggml/src/iqk/iqk_common.h 21
1 files changed, 18 insertions, 3 deletions
diff --git a/ggml/src/iqk/iqk_common.h b/ggml/src/iqk/iqk_common.h
index 6feeff1a..cce040dd 100644
--- a/ggml/src/iqk/iqk_common.h
+++ b/ggml/src/iqk/iqk_common.h
@@ -225,9 +225,15 @@ static inline int hsum_i32_8(const __m256i a) {
const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
const __m128i sum64 = _mm_add_epi32(hi64, sum128);
- const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+ const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
}
+static inline float hmax_f32_8(__m256 x) { // Returns the largest of the 8 floats held in x. NOTE(review): looks identical to hmax_float_8 below - confirm both are needed.
+ __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); // element-wise max of the upper and lower 128-bit halves -> 4 candidates
+ max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); // fold elements 2,3 onto elements 0,1 -> 2 candidates
+ max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); // fold element 1 onto element 0 -> overall max in element 0
+ return _mm_cvtss_f32(max4); // extract element 0 as a scalar
+}
static inline float hmax_float_8(__m256 x) {
__m128 max4 = _mm_max_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
max4 = _mm_max_ps( max4, _mm_movehl_ps(max4, max4));
@@ -235,13 +241,22 @@ static inline float hmax_float_8(__m256 x) {
return _mm_cvtss_f32(max4);
}
+static inline __m128 hsum_float_4x4(__m128 * accm) { // Reads accm[0..3]; returns a vector whose element k is the sum of the 4 floats in accm[k]. Overwrites accm[0] and accm[1].
+ accm[0] = _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[2]), _mm_unpackhi_ps(accm[0], accm[2])); // interleaved pairwise partial sums of accm[0] and accm[2]
+ accm[1] = _mm_add_ps(_mm_unpacklo_ps(accm[1], accm[3]), _mm_unpackhi_ps(accm[1], accm[3])); // interleaved pairwise partial sums of accm[1] and accm[3]
+ return _mm_add_ps(_mm_unpacklo_ps(accm[0], accm[1]), _mm_unpackhi_ps(accm[0], accm[1])); // second interleave-and-add completes each 4-float sum, ordered by accumulator index
+}
static inline __m256 hsum_float_8x8(__m256 * accm) {
for (int i = 0; i < 4; ++i) {
- accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i+4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i+4], 0x31));
+ accm[i] = _mm256_add_ps(_mm256_permute2f128_ps(accm[i], accm[i + 4], 0x20), _mm256_permute2f128_ps(accm[i], accm[i + 4], 0x31)); // low 128 bits <- lower+upper half of accm[i]; high 128 bits <- lower+upper half of accm[i + 4]
//accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
// _mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
}
- for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i + 2]), _mm256_unpackhi_ps(accm[i], accm[i + 2])); // interleave-and-add within each 128-bit half; overwrites accm[0] and accm[1]
+ return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1])); // element k of the result is the sum of the 8 floats originally in accm[k]; accm[0..3] have been clobbered
+}
+static inline __m256 hsum_float_4x8(__m256 * accm) { // Reads accm[0..3]. Result element j (j < 4) is the sum of the lower four floats of accm[j]; element 4 + j is the sum of the upper four floats of accm[j].
+ for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i + 2]), _mm256_unpackhi_ps(accm[i], accm[i + 2])); // unpack works per 128-bit half, so the two halves are never mixed; overwrites accm[0] and accm[1]
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}