diff options
Diffstat (limited to 'ggml.c')
-rw-r--r-- | ggml.c | 24 |
1 files changed, 16 insertions, 8 deletions
@@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) { int i = 0; #if defined(__AVX512BF16__) for (; i + 32 <= n; i += 32) { - _mm512_storeu_ps( - (__m512 *)(y + i), - (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16), - _mm512_loadu_ps(x + i))); + _mm512_storeu_si512( + (__m512i *)(y + i), + m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16), + _mm512_loadu_ps(x + i)))); } #endif for (; i < n; i++) { @@ -1666,10 +1666,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t __m512 c1 = _mm512_setzero_ps(); __m512 c2 = _mm512_setzero_ps(); for (; i + 64 <= n; i += 64) { - c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)), - (__m512bh)_mm512_loadu_ps((const float *)(y + i))); - c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)), - (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32))); + c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))), + m512bh(_mm512_loadu_si512((y + i)))); + c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))), + m512bh(_mm512_loadu_si512((y + i + 32)))); } sumf += (ggml_float)_mm512_reduce_add_ps(c1); sumf += (ggml_float)_mm512_reduce_add_ps(c2); @@ -23137,6 +23137,14 @@ int ggml_cpu_has_avx512_vnni(void) { #endif } +int ggml_cpu_has_avx512_bf16(void) { +#if defined(__AVX512BF16__) + return 1; +#else + return 0; +#endif +} + int ggml_cpu_has_fma(void) { #if defined(__FMA__) return 1; |