diff options
author | Chris Elrod <elrodc@gmail.com> | 2024-05-30 07:32:55 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-30 21:32:55 +1000 |
commit | 59b0d077662fab430446b3119fa142f3291c45b2 (patch) | |
tree | 8b95e7bd69048be614a85660667af5ac38436914 /ggml.c | |
parent | d5c05821f3c3d6cabe8ac45776fe0ecb0da13eca (diff) |
faster avx512 exp implementation (#7551)
* faster avx512 exp implementation
* x->r
* improve accuracy, handle special cases
* remove `e`
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 43 |
1 files changed, 19 insertions, 24 deletions
@@ -2315,32 +2315,27 @@ inline static __m512 ggml_v_expf(__m512 x) { const __m512 r = _mm512_set1_ps(0x1.8p23f); const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); const __m512 n = _mm512_sub_ps(z, r); - const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), - _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); - const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23); - const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1)))); - const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ); - const __m512 u = _mm512_mul_ps(b, b); - const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, - _mm512_set1_ps(0x1.573e2ep-5f)), u, - _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, - _mm512_set1_ps(0x1.fffdb6p-2f))), - u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b)); - if (_mm512_kortestz(c, c)) - return _mm512_fmadd_ps(j, k, k); - const __m512i g = _mm512_and_si512( - _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)), - _mm512_set1_epi32(0x82000000u)); - const __m512 s1 = - _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u))); - const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g)); + const __m512 b = + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), + _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); const __mmask16 d = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); - return _mm512_mask_blend_ps( - d, _mm512_mask_blend_ps( - c, _mm512_fmadd_ps(k, j, k), - _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)), - _mm512_mul_ps(s1, s1)); + const __m512 u = _mm512_mul_ps(b, b); + const __m512 j = _mm512_fmadd_ps( + _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, + _mm512_set1_ps(0x1.573e2ep-5f)), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, + _mm512_set1_ps(0x1.fffdb6p-2f))), + u, + _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); + const __m512 res = _mm512_scalef_ps(j, n); + if (_mm512_kortestz(d, d)) + return res; + const __m512 zero = _mm512_setzero_ps(); + const __m512 alt = _mm512_mask_blend_ps( + _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); + return _mm512_mask_blend_ps(d, res, alt); } // computes silu x/(1+exp(-x)) in single precision vector |