#pragma once

#include "iqk_config.h"

#if defined IQK_IMPLEMENT

#include "ggml-impl.h"

#if defined(__ARM_NEON) && defined(__aarch64__)

// copy-pasted from Justine Tunney's contribution to llama.cpp
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
static inline float32x4_t v_expf(float32x4_t x) {
    const float32x4_t r = vdupq_n_f32(0x1.8p23f);
    const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f));
    const float32x4_t n = vsubq_f32(z, r);
    const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, vdupq_n_f32(0x1.7f7d1cp-20f));
    const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23);
    const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1))));
    const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126));
    const float32x4_t u = vmulq_f32(b, b);
    const float32x4_t j = vfmaq_f32(
        vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b),
        vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b),
                  vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u);
    if (!vpaddd_u64(vreinterpretq_u64_u32(c)))
        return vfmaq_f32(k, j, k);
    const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
    const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000)));
    const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d));
    return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1),
                     vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j)));
}

// fp16 variant: convert to fp32, evaluate, convert back
static inline float16x8_t v_expf(float16x8_t x) {
    auto val1 = v_expf(vcvt_f32_f16(vget_low_f16(x)));
    auto val2 = v_expf(vcvt_f32_f16(vget_high_f16(x)));
    return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
}

// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1), clamped to 1 for x > 10 to avoid overflow in v_expf
static inline float32x4_t v_tanh(float32x4_t x) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    const float32x4_t two_x = vmulq_f32(x, vdupq_n_f32(2.f));
    const float32x4_t exp_two_x = v_expf(two_x);
    const uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
    const float32x4_t res = vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
    return vreinterpretq_f32_u32(vorrq_u32(vandq_u32(vreinterpretq_u32_f32(one), mask),
                                           vbicq_u32(vreinterpretq_u32_f32(res), mask)));
    //return vdivq_f32(vsubq_f32(exp_two_x, one), vaddq_f32(exp_two_x, one));
}
//inline float32x4_t v_tanh(float16x8_t x) {
//    auto val1 = v_tanh(vcvt_f32_f16(vget_low_f16(x)));
//    auto val2 = v_tanh(vcvt_f32_f16(vget_high_f16(x)));
//    return vcombine_f16(vcvt_f16_f32(val1), vcvt_f16_f32(val2));
//}

// silu(x) = x / (1 + exp(-x))
static inline float32x4_t v_silu(float32x4_t x) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    const float32x4_t zero = vdupq_n_f32(0.0f);
    const float32x4_t neg_x = vsubq_f32(zero, x);
    const float32x4_t exp_neg_x = v_expf(neg_x);
    const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
    return vdivq_f32(x, one_plus_exp_neg_x);
}

// gelu(x) = x * sigmoid(c2 * x * (1 + c1 * x^2)); returns x directly for x > 10 where the sigmoid saturates
static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
    const float32x4_t one = vdupq_n_f32(1.0f);
    float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
    arg = vmulq_f32(arg, vmulq_f32(x, c2));
    float32x4_t exp_arg = v_expf(arg);
    float32x4_t gelu = vmulq_f32(x, vdivq_f32(exp_arg, vaddq_f32(exp_arg, one)));
    uint32x4_t mask = vcgtq_f32(x, vdupq_n_f32(10.f));
    return vbslq_f32(mask, x, gelu);
}
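
// Illustrative sketch, not part of the original interface: one way the 4-wide kernels above are
// typically applied to a contiguous buffer, using a zero-padded temporary for the 1..3 element
// scalar tail (silu(0) = 0, so the padding is harmless). The helper name is hypothetical.
static inline void v_silu_apply_sketch(float * y, const float * x, int n) {
    int i = 0;
    for (; i + 4 <= n; i += 4) {
        vst1q_f32(y + i, v_silu(vld1q_f32(x + i)));  // full 4-element vectors
    }
    if (i < n) {
        float tmp[4] = {0.f, 0.f, 0.f, 0.f};
        for (int j = i; j < n; ++j) tmp[j - i] = x[j];
        vst1q_f32(tmp, v_silu(vld1q_f32(tmp)));
        for (int j = i; j < n; ++j) y[j] = tmp[j - i];
    }
}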

#endif // __ARM_NEON && __aarch64__

#if defined(__AVX512F__) && defined(__AVX512DQ__)

// copy-pasted from Justine Tunney's contribution to llama.cpp
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
static inline __m512 v_expf(__m512 x) {
    const __m512 r = _mm512_set1_ps(0x1.8p23f);
    const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
    const __m512 n = _mm512_sub_ps(z, r);
    const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
                                      _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
    const __mmask16 d = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
    const __m512 u = _mm512_mul_ps(b, b);
    const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
                                                                     _mm512_set1_ps(0x1.573e2ep-5f)),
                                                     u,
                                                     _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
                                                                     _mm512_set1_ps(0x1.fffdb6p-2f))),
                                     u,
                                     _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
    const __m512 res = _mm512_scalef_ps(j, n);
    if (_mm512_kortestz(d, d))
        return res;
    const __m512 zero = _mm512_setzero_ps();
    const __m512 alt = _mm512_mask_blend_ps(_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
    return _mm512_mask_blend_ps(d, res, alt);
}

// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1), clamped to 1 for x > 10
static inline __m512 v_tanh(__m512 x) {
    const __m512 one = _mm512_set1_ps(1.0f);
    const __m512 exp_two_x = v_expf(_mm512_mul_ps(x, _mm512_set1_ps(2.f)));
    const __mmask16 mask = _mm512_cmp_ps_mask(x, _mm512_set1_ps(10.f), _CMP_GT_OQ);
    const __m512 res = _mm512_div_ps(_mm512_sub_ps(exp_two_x, one), _mm512_add_ps(exp_two_x, one));
    return _mm512_mask_blend_ps(mask, res, one);
}

// gelu(x) = x * sigmoid(c2 * x * (1 + c1 * x^2)); for arg > 30 the sigmoid is treated as 1
static inline __m512 v_gelu(__m512 x, __m512 c1, __m512 c2) {
    const __m512 one = _mm512_set1_ps(1.0f);
    __m512 arg = _mm512_fmadd_ps(x, _mm512_mul_ps(c1, x), one);
    //__m512 arg = _mm512_add_ps(one, _mm512_mul_ps(_mm512_mul_ps(x, x), c1));
    arg = _mm512_mul_ps(arg, _mm512_mul_ps(c2, x));
    const __mmask16 mask = _mm512_cmp_ps_mask(arg, _mm512_set1_ps(30.f), _CMP_GT_OQ);
    const __m512 exp_arg = v_expf(arg);
    const __m512 ratio = _mm512_div_ps(exp_arg, _mm512_add_ps(exp_arg, one));
    return _mm512_mul_ps(x, _mm512_mask_blend_ps(mask, ratio, one));
}

// silu(x) = x / (1 + exp(-x))
static inline __m512 v_silu(__m512 x) {
    const __m512 one = _mm512_set1_ps(1);
    const __m512 zero = _mm512_setzero_ps();
    const __m512 neg_x = _mm512_sub_ps(zero, x);
    const __m512 exp_neg_x = v_expf(neg_x);
    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
    return _mm512_div_ps(x, one_plus_exp_neg_x);
}
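
// Illustrative sketch, not part of the original interface: v_gelu() above evaluates
// x * sigmoid(c2 * x * (1 + c1 * x^2)), which reproduces the common tanh-based GELU approximation
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) when c1 = 0.044715 and c2 = 2*sqrt(2/pi).
// The wrapper name and the assumption about which constants callers pass are hypothetical.
static inline __m512 v_gelu_default_sketch(__m512 x) {
    const __m512 c1 = _mm512_set1_ps(0.044715f);           // cubic coefficient of the approximation
    const __m512 c2 = _mm512_set1_ps(1.5957691216057308f); // 2*sqrt(2/pi)
    return v_gelu(x, c1, c2);
}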

#endif // __AVX512F__ && __AVX512DQ__

#if defined(__AVX2__) && defined(__FMA__)

// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
// numbers above 88.38 will flush to infinity
// numbers beneath -103.97 will flush to zero
static inline __m256 v_expf(__m256 x) {
    const __m256 r = _mm256_set1_ps(0x1.8p23f);
    const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
    const __m256 n = _mm256_sub_ps(z, r);
    const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
                                      _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
    const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
    const __m256 k = _mm256_castsi256_ps(_mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
    const __m256i c = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), _mm256_set1_ps(126), _CMP_GT_OQ));
    const __m256 u = _mm256_mul_ps(b, b);
    const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
                                                                     _mm256_set1_ps(0x1.573e2ep-5f)),
                                                     u,
                                                     _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
                                                                     _mm256_set1_ps(0x1.fffdb6p-2f))),
                                     u,
                                     _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
    if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
        return _mm256_fmadd_ps(j, k, k);
    const __m256i g = _mm256_and_si256(
        _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
        _mm256_set1_epi32(0x82000000u));
    const __m256 s1 = _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
    const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
    const __m256i d = _mm256_castps_si256(
        _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), _mm256_set1_ps(192), _CMP_GT_OQ));
    return _mm256_or_ps(
        _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
        _mm256_andnot_ps(
            _mm256_castsi256_ps(d),
            _mm256_or_ps(
                _mm256_and_ps(_mm256_castsi256_ps(c), _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
                _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
}

// tanh(x) = (exp(2x) - 1) / (exp(2x) + 1), clamped to 1 for x > 10
static inline __m256 v_tanh(__m256 x) {
    const __m256 one = _mm256_set1_ps(1.0f);
    const __m256 exp_two_x = v_expf(_mm256_mul_ps(x, _mm256_set1_ps(2.f)));
    const __m256 res = _mm256_div_ps(_mm256_sub_ps(exp_two_x, one), _mm256_add_ps(exp_two_x, one));
    const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
    return _mm256_or_ps(_mm256_and_ps(mask, one), _mm256_andnot_ps(mask, res));
}

// gelu(x) = x * sigmoid(c2 * x * (1 + c1 * x^2)); returns x directly for x > 10 where the sigmoid saturates
static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
    const __m256 one = _mm256_set1_ps(1.0f);
    const __m256 mask = _mm256_cmp_ps(x, _mm256_set1_ps(10.f), _CMP_GT_OQ);
    __m256 arg = _mm256_add_ps(one, _mm256_mul_ps(_mm256_mul_ps(x, x), c1));
    arg = _mm256_mul_ps(arg, _mm256_mul_ps(x, c2));
    __m256 exp_arg = v_expf(arg);
    __m256 gelu = _mm256_mul_ps(x, _mm256_div_ps(exp_arg, _mm256_add_ps(exp_arg, one)));
    return _mm256_or_ps(_mm256_and_ps(mask, x), _mm256_andnot_ps(mask, gelu));
}

// silu(x) = x / (1 + exp(-x))
static inline __m256 v_silu(__m256 x) {
    const __m256 one = _mm256_set1_ps(1);
    const __m256 zero = _mm256_setzero_ps();
    const __m256 neg_x = _mm256_sub_ps(zero, x);
    const __m256 exp_neg_x = v_expf(neg_x);
    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
    return _mm256_div_ps(x, one_plus_exp_neg_x);
}

#endif // __AVX2__ && __FMA__

#endif // IQK_IMPLEMENT
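
#if defined IQK_IMPLEMENT && defined(__AVX2__) && defined(__FMA__)
#include <cmath>
// Illustrative self-check sketch, not part of the original header: the comments above quote a
// maximum error of 1.45358 + 0.5 ulps for v_expf and flushing behaviour outside roughly
// [-103.97, 88.38]. A sweep like this measures the maximum relative error against the libm
// reference as a rough proxy for that ulp bound. The function name is hypothetical.
static inline float v_expf_max_rel_error_sketch() {
    float worst = 0.0f;
    for (float x = -87.0f; x <= 88.0f; x += 0.317f) {
        const float got = _mm256_cvtss_f32(v_expf(_mm256_set1_ps(x))); // lane 0 of the vector result
        const float ref = std::exp(x);                                 // scalar reference, exp(x) > 0 here
        const float rel = std::fabs(got - ref) / ref;
        if (rel > worst) worst = rel;
    }
    return worst;
}
#endif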