author     Srihari-mcw <96763064+Srihari-mcw@users.noreply.github.com>  2024-05-19 19:18:39 -0700
committer  GitHub <noreply@github.com>  2024-05-20 12:18:39 +1000
commit     33c8d50accd6dca73c9c4af00a05e24209c160fe (patch)
tree       926d8b9ad683420872afb234f5cd94dba3a3c500
parent     d359f30921a9f62a0fd299c412ff3f270286fea6 (diff)
Add Windows support for the BF16 code path, including a CMake option for enabling AVX512_BF16 (#7258)
-rw-r--r--  CMakeLists.txt   8
-rw-r--r--  ggml-impl.h     12
-rw-r--r--  ggml.c          24
-rw-r--r--  ggml.h           1
-rw-r--r--  llama.cpp        1
5 files changed, 38 insertions(+), 8 deletions(-)
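This patch makes the BF16 (brain-float16) kernels build on Windows: MSVC has no -mavx512bf16 switch and does not accept the GNU-style vector casts the AVX512-BF16 intrinsics were wrapped in, so the change adds an explicit LLAMA_AVX512_BF16 CMake option, a pair of cast-wrapper macros, and a feature-report hook. On a toolchain that supports the instructions, the path is opted into at configure time with cmake -B build -DLLAMA_AVX512_BF16=ON.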
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 616698c7..92c9f09e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -77,6 +77,7 @@ option(LLAMA_AVX2 "llama: enable AVX2"
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
+option(LLAMA_AVX512_BF16 "llama: enable AVX512-BF16" OFF)
option(LLAMA_FMA "llama: enable FMA" ${INS_ENB})
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
@@ -1060,6 +1061,10 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
+ if (LLAMA_AVX512_BF16)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
+ endif()
elseif (LLAMA_AVX2)
list(APPEND ARCH_FLAGS /arch:AVX2)
elseif (LLAMA_AVX)
@@ -1091,6 +1096,9 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
if (LLAMA_AVX512_VNNI)
list(APPEND ARCH_FLAGS -mavx512vnni)
endif()
+ if (LLAMA_AVX512_BF16)
+ list(APPEND ARCH_FLAGS -mavx512bf16)
+ endif()
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
message(STATUS "PowerPC detected")
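The two CMake hunks mirror each other. MSVC exposes no dedicated /arch level for BF16, so on that branch the build defines __AVX512BF16__ as a preprocessor symbol by hand; GCC and Clang receive the real -mavx512bf16 flag and define the macro themselves. Either way, the kernels below can gate on a single macro. A hypothetical compile probe (not part of the patch) that succeeds only when the toolchain understands the instructions:

    #include <immintrin.h>

    /* Compiles only if the compiler knows the AVX512-BF16 intrinsics,
     * i.e. __AVX512BF16__ is in effect for this translation unit. */
    __m512bh probe(__m512 hi, __m512 lo) {
        return _mm512_cvtne2ps_pbh(hi, lo);
    }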
diff --git a/ggml-impl.h b/ggml-impl.h
index 59684fa8..5ff014fe 100644
--- a/ggml-impl.h
+++ b/ggml-impl.h
@@ -17,6 +17,18 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#if defined(_WIN32)
+
+#define m512bh(p) p
+#define m512i(p) p
+
+#else
+
+#define m512bh(p) (__m512bh)(p)
+#define m512i(p) (__m512i)(p)
+
+#endif
+
/**
* Converts brain16 to float32.
*
diff --git a/ggml.c b/ggml.c
index 3a104c48..53da231e 100644
--- a/ggml.c
+++ b/ggml.c
@@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
int i = 0;
#if defined(__AVX512BF16__)
for (; i + 32 <= n; i += 32) {
- _mm512_storeu_ps(
- (__m512 *)(y + i),
- (__m512)_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
- _mm512_loadu_ps(x + i)));
+ _mm512_storeu_si512(
+ (__m512i *)(y + i),
+ m512i(_mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
+ _mm512_loadu_ps(x + i))));
}
#endif
for (; i < n; i++) {
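_mm512_cvtne2ps_pbh folds two vectors of 16 floats into one vector of 32 bf16 values using round-to-nearest-even; the rewritten store then moves those raw bits out with an integer store through m512i(), instead of casting the bf16 vector to __m512 floats, a cast MSVC cannot compile. For reference, a scalar sketch of the same conversion (illustrative only; it omits the NaN handling a production path needs):

    #include <stdint.h>
    #include <string.h>

    /* bf16 keeps a float's sign, exponent, and top 7 mantissa bits;
     * adding 0x7FFF plus the keep-bit rounds to nearest even. */
    static uint16_t fp32_to_bf16_ref(float f) {
        uint32_t u;
        memcpy(&u, &f, sizeof u);
        u += 0x7FFF + ((u >> 16) & 1);
        return (uint16_t)(u >> 16);
    }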
@@ -1666,10 +1666,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
__m512 c1 = _mm512_setzero_ps();
__m512 c2 = _mm512_setzero_ps();
for (; i + 64 <= n; i += 64) {
- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)(x + i)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i)));
- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)(x + i + 32)),
- (__m512bh)_mm512_loadu_ps((const float *)(y + i + 32)));
+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))),
+ m512bh(_mm512_loadu_si512((y + i))));
+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))),
+ m512bh(_mm512_loadu_si512((y + i + 32))));
}
sumf += (ggml_float)_mm512_reduce_add_ps(c1);
sumf += (ggml_float)_mm512_reduce_add_ps(c2);
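Each _mm512_dpbf16_ps call multiplies 32 bf16 pairs and accumulates the pairwise sums into 16 f32 lanes, so the two independent accumulators c1 and c2 cover 64 elements per iteration while overlapping instruction latency; the rewrite only swaps the float-typed loads plus GNU casts for raw integer loads wrapped in m512bh(). A scalar reference for the same reduction (hypothetical helper names; ggml's own GGML_BF16_TO_FP32 does the equivalent widening):

    #include <stdint.h>
    #include <string.h>

    /* A bf16 value is just the top half of a float's bit pattern. */
    static float bf16_to_fp32_ref(uint16_t h) {
        uint32_t u = (uint32_t)h << 16;
        float f;
        memcpy(&f, &u, sizeof f);
        return f;
    }

    static float dot_bf16_ref(int n, const uint16_t *x, const uint16_t *y) {
        float s = 0.0f;
        for (int i = 0; i < n; i++) {
            s += bf16_to_fp32_ref(x[i]) * bf16_to_fp32_ref(y[i]);
        }
        return s;
    }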
@@ -23137,6 +23137,14 @@ int ggml_cpu_has_avx512_vnni(void) {
#endif
}
+int ggml_cpu_has_avx512_bf16(void) {
+#if defined(__AVX512BF16__)
+ return 1;
+#else
+ return 0;
+#endif
+}
+
int ggml_cpu_has_fma(void) {
#if defined(__FMA__)
return 1;
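Like its siblings, ggml_cpu_has_avx512_bf16() reports a compile-time property, namely whether this binary was built with __AVX512BF16__, not a runtime CPUID probe. A hypothetical caller-side check:

    #include <stdio.h>
    #include "ggml.h"

    /* Warn when the bf16 fast path was not compiled into this build. */
    static void warn_if_no_bf16(void) {
        if (!ggml_cpu_has_avx512_bf16()) {
            fprintf(stderr, "note: AVX512-BF16 kernels not compiled in\n");
        }
    }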
diff --git a/ggml.h b/ggml.h
index 8c13f4ba..77475710 100644
--- a/ggml.h
+++ b/ggml.h
@@ -2390,6 +2390,7 @@ extern "C" {
GGML_API int ggml_cpu_has_avx512 (void);
GGML_API int ggml_cpu_has_avx512_vbmi(void);
GGML_API int ggml_cpu_has_avx512_vnni(void);
+ GGML_API int ggml_cpu_has_avx512_bf16(void);
GGML_API int ggml_cpu_has_fma (void);
GGML_API int ggml_cpu_has_neon (void);
GGML_API int ggml_cpu_has_arm_fma (void);
diff --git a/llama.cpp b/llama.cpp
index 102bc202..ca3e9fcc 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -18074,6 +18074,7 @@ const char * llama_print_system_info(void) {
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
+ s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";