Diffstat (limited to 'ggml/src/iqk/iqk_utils.h')
-rw-r--r--  ggml/src/iqk/iqk_utils.h  207
1 file changed, 207 insertions, 0 deletions
diff --git a/ggml/src/iqk/iqk_utils.h b/ggml/src/iqk/iqk_utils.h
new file mode 100644
index 00000000..194bf9b8
--- /dev/null
+++ b/ggml/src/iqk/iqk_utils.h
@@ -0,0 +1,207 @@
+#pragma once
+
+#include "iqk_config.h"
+
+#if defined IQK_IMPLEMENT
+
+#include "ggml-impl.h"
+
+#if defined(__ARM_NEON) && defined(__aarch64__)
+// copy-pasted from Justine Tunney's contribution to llama.cpp
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+static inline float32x4_t v_expf(float32x4_t x) {
+ const float32x4_t r = vdupq_n_f32(0x1.8p23f);
+ const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
+ const float32x4_t n = vsubq_f32(z, r);
+ const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n,
+ vdupq_n_f32(0x1.7f7d1cp-20f));
+ const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
+ const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
+ const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
+ const float32x4_t u = vmulq_f32(b, b);
+ const float32x4_t j = vfmaq_f32(
+ vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
+ vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
+ vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
+ if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
+ return vfmaq_f32(k, j, k);
+ const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
+ const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
+ const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
+ return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
+ vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
+}
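
For reference, the range reduction shared by all v_expf variants in this file can be written out in scalar form. The sketch below is illustrative only (same constants as above, no handling of the overflow/underflow branches): n = round(x*log2(e)) is recovered via the 0x1.8p23 shift trick, b = x - n*ln(2) is the reduced argument (ln(2) split into a high and a low part), a degree-5 polynomial approximates exp(b), and the result is scaled by 2^n.

#include <math.h>

// Illustrative scalar sketch of the v_expf range reduction; not part of this header.
static inline float expf_sketch(float x) {
    const float r = 0x1.8p23f;                        // 1.5 * 2^23: rounding shift
    float z = fmaf(x, 0x1.715476p+0f, r);             // x*log2(e) + shift
    float n = z - r;                                  // round(x*log2(e))
    float b = fmaf(n, -0x1.62e4p-1f, x);              // x - n*ln2_hi
    b = fmaf(n, -0x1.7f7d1cp-20f, b);                 //   - n*ln2_lo
    float u = b*b;
    float j = fmaf(fmaf(fmaf(0x1.0e4020p-7f, b, 0x1.573e2ep-5f), u,
                        fmaf(0x1.555e66p-3f, b, 0x1.fffdb6p-2f)), u,
                   fmaf(0x1.ffffecp-1f, b, 1.0f));    // polynomial for exp(b)
    return ldexpf(j, (int)n);                         // scale by 2^n
}
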
+static inline float16x8_t v_expf(float16x8_t x) {
+ auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
+ auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
+ return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+}
+static inline float32x4_t v_tanh(float32x4_t x) {
+ const float32x4_t one = vdupq_n_f32(1.0f);
+ const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
+ const float32x4_t exp_two_x = v_expf(two_x);
+ const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
+ const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+ return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask), vbicq_u32(vreinterpretq_u32_f32(res), mask)));
+ //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
+}
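
v_tanh relies on the identity tanh(x) = (e^(2x) - 1) / (e^(2x) + 1) and clamps the result to 1 for x > 10, where tanh(x) is already 1.0f to float precision, so that an overflowing e^(2x) (2x above ~88.38, per the v_expf comment) can never turn the quotient into inf/inf = NaN. An illustrative scalar equivalent:

#include <math.h>

// Scalar equivalent of v_tanh above (illustrative only).
static inline float tanh_ref(float x) {
    if (x > 10.0f) return 1.0f;          // clamp before e^(2x) can overflow
    float e = expf(2.0f*x);
    return (e - 1.0f)/(e + 1.0f);
}
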
+//static inline float16x8_t v_tanh(float16x8_t x) {
+// auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
+// auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
+// return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
+//}
+static inline float32x4_t v_silu(float32x4_t x) {
+ const float32x4_t one = vdupq_n_f32(1.0f);
+ const float32x4_t zero = vdupq_n_f32(0.0f);
+ const float32x4_t neg_x = vsubq_f32(zero, x);
+ const float32x4_t exp_neg_x = v_expf(neg_x);
+ const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
+ return vdivq_f32(x, one_plus_exp_neg_x);
+}
+static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
+ const float32x4_t one = vdupq_n_f32(1.0f);
+ float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
+ arg = vmulq_f32(arg, vmulq_f32(x, c2));
+ float32x4_t exp_arg = v_expf(arg);
+ float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one)));
+ uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
+ return vbslq_f32(mask, x, gelu);
+}
+
+#endif // __ARM_NEON && __aarch64__
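
A minimal caller sketch for the NEON block above (illustrative; silu_f32 and its signature are not part of this commit): process four lanes at a time with v_silu and fall back to scalar code for the tail.

#include <arm_neon.h>
#include <math.h>

// Illustrative: apply SiLU to n floats using the NEON v_silu above.
static void silu_f32(const float * x, float * y, int n) {
    int i = 0;
    for (; i + 4 <= n; i += 4) {
        vst1q_f32(y + i, v_silu(vld1q_f32(x + i)));   // 4 lanes per iteration
    }
    for (; i < n; ++i) {
        y[i] = x[i]/(1.0f + expf(-x[i]));             // scalar tail
    }
}
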
+
+#if defined(__AVX512F__) && defined(__AVX512DQ__)
+
+// copy-pasted from Justine Tunney's contribution to llama.cpp
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+static inline __m512 v_expf(__m512 x) {
+ const __m512 r = _mm512_set1_ps(0x1.8p23f);
+ const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
+ const __m512 n = _mm512_sub_ps(z, r);
+ const __m512 b =
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
+ _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
+ const __mmask16 d =
+ _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
+ const __m512 u = _mm512_mul_ps(b, b);
+ const __m512 j = _mm512_fmadd_ps(
+ _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
+ _mm512_set1_ps(0x1.573e2ep-5f)),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
+ _mm512_set1_ps(0x1.fffdb6p-2f))),
+ u,
+ _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
+ const __m512 res = _mm512_scalef_ps(j, n);
+ if (_mm512_kortestz(d, d))
+ return res;
+ const __m512 zero = _mm512_setzero_ps();
+ const __m512 alt = _mm512_mask_blend_ps(
+ _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
+ return _mm512_mask_blend_ps(d, res, alt);
+}
+static inline __m512 v_tanh(__m512 x) {
+ const __m512 one = _mm512_set1_ps(1.0f);
+ const __m512 exp_two_x = v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
+ const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
+ const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
+ return _mm512_mask_blend_ps(mask, res, one);
+}
+static inline __m512 v_gelu(__m512 x, __m512 c1, __m512 c2) {
+ const __m512 one = _mm512_set1_ps(1.0f);
+ __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
+ //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
+ arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
+ const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
+ const __m512 exp_arg = v_expf(arg);
+ const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
+ return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
+}
+static inline __m512 v_silu(__m512 x) {
+ const __m512 one = _mm512_set1_ps(1);
+ const __m512 zero = _mm512_setzero_ps();
+ const __m512 neg_x = _mm512_sub_ps(zero, x);
+ const __m512 exp_neg_x = v_expf(neg_x);
+ const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
+ return _mm512_div_ps(x, one_plus_exp_neg_x);
+}
+#endif // __AVX512F__ && __AVX512DQ__
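
A hedged usage sketch for the 16-lane variants (expf_row_sum is illustrative, not part of this header, and n is assumed to be a multiple of 16): exponentiate a row with v_expf and accumulate the sum, as a softmax-style first pass would.

#include <immintrin.h>

// Illustrative: y[i] = exp(x[i]) for i < n (n a multiple of 16), returns the row sum.
static float expf_row_sum(const float * x, float * y, int n) {
    __m512 acc = _mm512_setzero_ps();
    for (int i = 0; i < n; i += 16) {
        __m512 e = v_expf(_mm512_loadu_ps(x + i));
        _mm512_storeu_ps(y + i, e);
        acc = _mm512_add_ps(acc, e);
    }
    return _mm512_reduce_add_ps(acc);                 // horizontal sum
}
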
+
+#if defined(__AVX2__) && defined(__FMA__)
+
+// adapted from arm limited optimized routine
+// the maximum error is 1.45358 plus 0.5 ulps
+// numbers above 88.38 will flush to infinity
+// numbers beneath -103.97 will flush to zero
+static inline __m256 v_expf(__m256 x) {
+ const __m256 r = _mm256_set1_ps(0x1.8p23f);
+ const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
+ const __m256 n = _mm256_sub_ps(z, r);
+ const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
+ _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
+ const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
+ const __m256 k = _mm256_castsi256_ps(
+ _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
+ const __m256i c = _mm256_castps_si256(
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+ _mm256_set1_ps(126), _CMP_GT_OQ));
+ const __m256 u = _mm256_mul_ps(b, b);
+ const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
+ _mm256_set1_ps(0x1.573e2ep-5f)), u,
+ _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
+ _mm256_set1_ps(0x1.fffdb6p-2f))),
+ u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
+ if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
+ return _mm256_fmadd_ps(j, k, k);
+ const __m256i g = _mm256_and_si256(
+ _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
+ _mm256_set1_epi32(0x82000000u));
+ const __m256 s1 =
+ _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
+ const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
+ const __m256i d = _mm256_castps_si256(
+ _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
+ _mm256_set1_ps(192), _CMP_GT_OQ));
+ return _mm256_or_ps(
+ _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
+ _mm256_andnot_ps(
+ _mm256_castsi256_ps(d),
+ _mm256_or_ps(
+ _mm256_and_ps(_mm256_castsi256_ps(c),
+ _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
+ _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
+}
+static inline __m256 v_tanh(__m256 x) {
+ const __m256 one = _mm256_set1_ps(1.0f);
+ const __m256 exp_two_x = v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
+ const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
+ const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
+ return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
+}
+static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
+ const __m256 one = _mm256_set1_ps(1.0f);
+ const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
+ __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1));
+ arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2));
+ __m256 exp_arg = v_expf(arg);
+ __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one)));
+ return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu));
+}
+static inline __m256 v_silu(__m256 x) {
+ const __m256 one = _mm256_set1_ps(1);
+ const __m256 zero = _mm256_setzero_ps();
+ const __m256 neg_x = _mm256_sub_ps(zero, x);
+ const __m256 exp_neg_x = v_expf(neg_x);
+ const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
+ return _mm256_div_ps(x, one_plus_exp_neg_x);
+}
+
+#endif // __AVX2__ && __FMA__
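
The v_gelu variants leave the two constants to the caller. Below is a usage sketch under the assumption that they follow ggml's tanh-approximation GELU, i.e. c1 = 0.044715 (GELU_COEF_A) and c2 = 2*sqrt(2/pi); gelu_f32 and these values are assumptions, not taken from this commit.

#include <immintrin.h>
#include <math.h>

// Illustrative AVX2 caller; the c1/c2 values are assumed, see the note above.
static void gelu_f32(const float * x, float * y, int n) {
    const __m256 c1 = _mm256_set1_ps(0.044715f);                     // assumed GELU_COEF_A
    const __m256 c2 = _mm256_set1_ps(2.0f*0.79788456080286535588f);  // assumed 2*sqrt(2/pi)
    int i = 0;
    for (; i + 8 <= n; i += 8) {
        _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2));
    }
    for (; i < n; ++i) {                                             // scalar tail, same approximation
        float t = tanhf(0.79788456080286535588f*x[i]*(1.0f + 0.044715f*x[i]*x[i]));
        y[i] = 0.5f*x[i]*(1.0f + t);
    }
}
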
+
+#endif // IQK_IMPLEMENT
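
Why a single v_expf call is enough for GELU: with y = sqrt(2/pi)*(x + 0.044715*x^3), the tanh form 0.5*x*(1 + tanh(y)) equals x*sigmoid(2y) = x*e^(2y)/(e^(2y) + 1), so passing c1 = 0.044715 and c2 = 2*sqrt(2/pi) makes arg = 2y and the division inside v_gelu is exactly that sigmoid. A scalar reference of what the vector code computes (illustrative only, constants assumed as above):

#include <math.h>

// Scalar reference for the v_gelu variants (illustrative; constants assumed).
static inline float gelu_ref(float x) {
    const float c1 = 0.044715f;
    const float c2 = 2.0f*0.79788456080286535588f;   // 2*sqrt(2/pi)
    float arg = c2*x*(1.0f + c1*x*x);
    float e   = expf(arg);
    return x*e/(e + 1.0f);
}
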