summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: 0cc4m <picard12@live.de>, 2024-02-01 19:25:24 +0100
committer: GitHub <noreply@github.com>, 2024-02-01 19:25:24 +0100
commit: 4d0924a8902010d31bd737b6f1f594943d120d0f (patch)
tree: 091227c1265488e6a528f280304b6ad92d6e8e17
parent: 8ca511cadee2c67f0bd8c7034a2513778ee9a1b7 (diff)
Vulkan Phi Fix for AMD Proprietary Drivers (#5260)
* Replace tanh to avoid NaN in gelu shader on AMD proprietary driver
* Fix another Vulkan CPY buffer size bug
-rw-r--r-- ggml-vulkan-shaders.hpp 132
-rw-r--r-- ggml-vulkan.cpp 17
-rw-r--r-- ggml_vk_generate_shaders.py 3
3 files changed, 83 insertions, 69 deletions
diff --git a/ggml-vulkan-shaders.hpp b/ggml-vulkan-shaders.hpp
index e2e9be22..195410c0 100644
--- a/ggml-vulkan-shaders.hpp
+++ b/ggml-vulkan-shaders.hpp
@@ -14670,14 +14670,14 @@ const uint64_t f32_to_f16_fp32_len = 1596;
unsigned char gelu_f32_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
-0x45,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
+0x4b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x11,0x00,0x02,0x00,
0x01,0x00,0x00,0x00,0x0b,0x00,0x06,0x00,0x01,0x00,0x00,0x00,
0x47,0x4c,0x53,0x4c,0x2e,0x73,0x74,0x64,0x2e,0x34,0x35,0x30,
0x00,0x00,0x00,0x00,0x0e,0x00,0x03,0x00,0x00,0x00,0x00,0x00,
0x01,0x00,0x00,0x00,0x0f,0x00,0x09,0x00,0x05,0x00,0x00,0x00,
0x04,0x00,0x00,0x00,0x6d,0x61,0x69,0x6e,0x00,0x00,0x00,0x00,
0x0b,0x00,0x00,0x00,0x14,0x00,0x00,0x00,0x24,0x00,0x00,0x00,
-0x2c,0x00,0x00,0x00,0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,
+0x38,0x00,0x00,0x00,0x10,0x00,0x06,0x00,0x04,0x00,0x00,0x00,
0x11,0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x01,0x00,0x00,0x00,
0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x0b,0x00,0x00,0x00,
0x0b,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,0x48,0x00,0x05,0x00,
@@ -14696,15 +14696,15 @@ unsigned char gelu_f32_data[] = {
0x22,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
0x24,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x47,0x00,0x04,0x00,0x24,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x29,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x35,0x00,0x00,0x00,
0x06,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x48,0x00,0x04,0x00,
-0x2a,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
-0x48,0x00,0x05,0x00,0x2a,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x36,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
+0x48,0x00,0x05,0x00,0x36,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
0x23,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x47,0x00,0x03,0x00,
-0x2a,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
-0x2c,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0x47,0x00,0x04,0x00,0x2c,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x42,0x00,0x00,0x00,
+0x36,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x47,0x00,0x04,0x00,
+0x38,0x00,0x00,0x00,0x22,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x47,0x00,0x04,0x00,0x38,0x00,0x00,0x00,0x21,0x00,0x00,0x00,
+0x01,0x00,0x00,0x00,0x47,0x00,0x04,0x00,0x48,0x00,0x00,0x00,
0x0b,0x00,0x00,0x00,0x19,0x00,0x00,0x00,0x13,0x00,0x02,0x00,
0x02,0x00,0x00,0x00,0x21,0x00,0x03,0x00,0x03,0x00,0x00,0x00,
0x02,0x00,0x00,0x00,0x15,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
@@ -14731,64 +14731,70 @@ unsigned char gelu_f32_data[] = {
0x23,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x22,0x00,0x00,0x00,
0x3b,0x00,0x04,0x00,0x23,0x00,0x00,0x00,0x24,0x00,0x00,0x00,
0x0c,0x00,0x00,0x00,0x20,0x00,0x04,0x00,0x26,0x00,0x00,0x00,
-0x0c,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x1d,0x00,0x03,0x00,
-0x29,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
-0x2a,0x00,0x00,0x00,0x29,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
-0x2b,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,
-0x3b,0x00,0x04,0x00,0x2b,0x00,0x00,0x00,0x2c,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,
+0x11,0x00,0x00,0x00,0x2a,0x00,0x00,0x00,0x2a,0x42,0x4c,0x3f,
+0x2b,0x00,0x04,0x00,0x11,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
+0x00,0x00,0x80,0x3f,0x2b,0x00,0x04,0x00,0x11,0x00,0x00,0x00,
+0x2e,0x00,0x00,0x00,0x13,0x27,0x37,0x3d,0x1d,0x00,0x03,0x00,
+0x35,0x00,0x00,0x00,0x11,0x00,0x00,0x00,0x1e,0x00,0x03,0x00,
+0x36,0x00,0x00,0x00,0x35,0x00,0x00,0x00,0x20,0x00,0x04,0x00,
+0x37,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x36,0x00,0x00,0x00,
+0x3b,0x00,0x04,0x00,0x37,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
0x0c,0x00,0x00,0x00,0x2b,0x00,0x04,0x00,0x11,0x00,0x00,0x00,
-0x2e,0x00,0x00,0x00,0x00,0x00,0x00,0x3f,0x2b,0x00,0x04,0x00,
-0x11,0x00,0x00,0x00,0x31,0x00,0x00,0x00,0x00,0x00,0x80,0x3f,
-0x2b,0x00,0x04,0x00,0x11,0x00,0x00,0x00,0x32,0x00,0x00,0x00,
-0x2a,0x42,0x4c,0x3f,0x2b,0x00,0x04,0x00,0x11,0x00,0x00,0x00,
-0x35,0x00,0x00,0x00,0x13,0x27,0x37,0x3d,0x2b,0x00,0x04,0x00,
-0x06,0x00,0x00,0x00,0x40,0x00,0x00,0x00,0x00,0x02,0x00,0x00,
-0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x2c,0x00,0x06,0x00,0x09,0x00,0x00,0x00,
-0x42,0x00,0x00,0x00,0x40,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
-0x41,0x00,0x00,0x00,0x36,0x00,0x05,0x00,0x02,0x00,0x00,0x00,
-0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x03,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x05,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,
-0x43,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xfb,0x00,0x03,0x00,
-0x0c,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x44,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x0d,0x00,0x00,0x00,
-0x0e,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,
-0x0e,0x00,0x00,0x00,0x41,0x00,0x05,0x00,0x17,0x00,0x00,0x00,
-0x18,0x00,0x00,0x00,0x14,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
-0x18,0x00,0x00,0x00,0xae,0x00,0x05,0x00,0x1a,0x00,0x00,0x00,
-0x1b,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x19,0x00,0x00,0x00,
-0xf7,0x00,0x03,0x00,0x1d,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
-0xfa,0x00,0x04,0x00,0x1b,0x00,0x00,0x00,0x1c,0x00,0x00,0x00,
-0x1d,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x1c,0x00,0x00,0x00,
-0xf9,0x00,0x02,0x00,0x43,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
-0x1d,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x26,0x00,0x00,0x00,
-0x27,0x00,0x00,0x00,0x24,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0x0f,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x11,0x00,0x00,0x00,
-0x28,0x00,0x00,0x00,0x27,0x00,0x00,0x00,0x85,0x00,0x05,0x00,
+0x3a,0x00,0x00,0x00,0x00,0x00,0x00,0x3f,0x2b,0x00,0x04,0x00,
+0x11,0x00,0x00,0x00,0x3d,0x00,0x00,0x00,0x00,0x00,0x00,0x40,
+0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,0x46,0x00,0x00,0x00,
+0x00,0x02,0x00,0x00,0x2b,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x47,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x2c,0x00,0x06,0x00,
+0x09,0x00,0x00,0x00,0x48,0x00,0x00,0x00,0x46,0x00,0x00,0x00,
+0x47,0x00,0x00,0x00,0x47,0x00,0x00,0x00,0x36,0x00,0x05,0x00,
+0x02,0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0x03,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x05,0x00,0x00,0x00,
+0xf7,0x00,0x03,0x00,0x49,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
+0xfb,0x00,0x03,0x00,0x0c,0x00,0x00,0x00,0x4a,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x4a,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x0d,0x00,0x00,0x00,0x0e,0x00,0x00,0x00,0x0b,0x00,0x00,0x00,
+0x0c,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x0f,0x00,0x00,0x00,0x0e,0x00,0x00,0x00,0x41,0x00,0x05,0x00,
+0x17,0x00,0x00,0x00,0x18,0x00,0x00,0x00,0x14,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,0x06,0x00,0x00,0x00,
+0x19,0x00,0x00,0x00,0x18,0x00,0x00,0x00,0xae,0x00,0x05,0x00,
+0x1a,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,
+0x19,0x00,0x00,0x00,0xf7,0x00,0x03,0x00,0x1d,0x00,0x00,0x00,
+0x00,0x00,0x00,0x00,0xfa,0x00,0x04,0x00,0x1b,0x00,0x00,0x00,
+0x1c,0x00,0x00,0x00,0x1d,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,
+0x1c,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x49,0x00,0x00,0x00,
+0xf8,0x00,0x02,0x00,0x1d,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x26,0x00,0x00,0x00,0x27,0x00,0x00,0x00,0x24,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x3d,0x00,0x04,0x00,
+0x11,0x00,0x00,0x00,0x28,0x00,0x00,0x00,0x27,0x00,0x00,0x00,
+0x85,0x00,0x05,0x00,0x11,0x00,0x00,0x00,0x2c,0x00,0x00,0x00,
+0x2a,0x00,0x00,0x00,0x28,0x00,0x00,0x00,0x85,0x00,0x05,0x00,
0x11,0x00,0x00,0x00,0x30,0x00,0x00,0x00,0x2e,0x00,0x00,0x00,
+0x28,0x00,0x00,0x00,0x0c,0x00,0x08,0x00,0x11,0x00,0x00,0x00,
+0x33,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x32,0x00,0x00,0x00,
+0x30,0x00,0x00,0x00,0x28,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,
+0x85,0x00,0x05,0x00,0x11,0x00,0x00,0x00,0x34,0x00,0x00,0x00,
+0x2c,0x00,0x00,0x00,0x33,0x00,0x00,0x00,0x85,0x00,0x05,0x00,
+0x11,0x00,0x00,0x00,0x3c,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
0x28,0x00,0x00,0x00,0x85,0x00,0x05,0x00,0x11,0x00,0x00,0x00,
-0x34,0x00,0x00,0x00,0x32,0x00,0x00,0x00,0x28,0x00,0x00,0x00,
-0x85,0x00,0x05,0x00,0x11,0x00,0x00,0x00,0x37,0x00,0x00,0x00,
-0x35,0x00,0x00,0x00,0x28,0x00,0x00,0x00,0x0c,0x00,0x08,0x00,
-0x11,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,0x01,0x00,0x00,0x00,
-0x32,0x00,0x00,0x00,0x37,0x00,0x00,0x00,0x28,0x00,0x00,0x00,
-0x31,0x00,0x00,0x00,0x85,0x00,0x05,0x00,0x11,0x00,0x00,0x00,
-0x3b,0x00,0x00,0x00,0x34,0x00,0x00,0x00,0x3a,0x00,0x00,0x00,
-0x0c,0x00,0x06,0x00,0x11,0x00,0x00,0x00,0x3c,0x00,0x00,0x00,
-0x01,0x00,0x00,0x00,0x15,0x00,0x00,0x00,0x3b,0x00,0x00,0x00,
-0x81,0x00,0x05,0x00,0x11,0x00,0x00,0x00,0x3d,0x00,0x00,0x00,
-0x31,0x00,0x00,0x00,0x3c,0x00,0x00,0x00,0x85,0x00,0x05,0x00,
-0x11,0x00,0x00,0x00,0x3e,0x00,0x00,0x00,0x30,0x00,0x00,0x00,
-0x3d,0x00,0x00,0x00,0x41,0x00,0x06,0x00,0x26,0x00,0x00,0x00,
-0x3f,0x00,0x00,0x00,0x2c,0x00,0x00,0x00,0x16,0x00,0x00,0x00,
-0x0f,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,0x3f,0x00,0x00,0x00,
-0x3e,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,0x43,0x00,0x00,0x00,
-0xf8,0x00,0x02,0x00,0x43,0x00,0x00,0x00,0xfd,0x00,0x01,0x00,
-0x38,0x00,0x01,0x00,
+0x3f,0x00,0x00,0x00,0x3d,0x00,0x00,0x00,0x34,0x00,0x00,0x00,
+0x0c,0x00,0x06,0x00,0x11,0x00,0x00,0x00,0x40,0x00,0x00,0x00,
+0x01,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x3f,0x00,0x00,0x00,
+0x81,0x00,0x05,0x00,0x11,0x00,0x00,0x00,0x41,0x00,0x00,0x00,
+0x40,0x00,0x00,0x00,0x2d,0x00,0x00,0x00,0x88,0x00,0x05,0x00,
+0x11,0x00,0x00,0x00,0x42,0x00,0x00,0x00,0x3d,0x00,0x00,0x00,
+0x41,0x00,0x00,0x00,0x83,0x00,0x05,0x00,0x11,0x00,0x00,0x00,
+0x43,0x00,0x00,0x00,0x3d,0x00,0x00,0x00,0x42,0x00,0x00,0x00,
+0x85,0x00,0x05,0x00,0x11,0x00,0x00,0x00,0x44,0x00,0x00,0x00,
+0x3c,0x00,0x00,0x00,0x43,0x00,0x00,0x00,0x41,0x00,0x06,0x00,
+0x26,0x00,0x00,0x00,0x45,0x00,0x00,0x00,0x38,0x00,0x00,0x00,
+0x16,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x3e,0x00,0x03,0x00,
+0x45,0x00,0x00,0x00,0x44,0x00,0x00,0x00,0xf9,0x00,0x02,0x00,
+0x49,0x00,0x00,0x00,0xf8,0x00,0x02,0x00,0x49,0x00,0x00,0x00,
+0xfd,0x00,0x01,0x00,0x38,0x00,0x01,0x00,
};
-const uint64_t gelu_f32_len = 1408;
+const uint64_t gelu_f32_len = 1484;
unsigned char get_rows_f16_data[] = {
0x03,0x02,0x23,0x07,0x00,0x05,0x01,0x00,0x0b,0x00,0x0d,0x00,
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index bccc40bf..b1e0006b 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -2876,6 +2876,9 @@ static void ggml_vk_op_f32(vk_context * ctx, const ggml_tensor * src0, const ggm
x_sz = ggml_nbytes(src0);
d_sz = ggml_nbytes(dst);
+ if (extra_src0->offset + x_sz >= d_X->size) {
+ x_sz = VK_WHOLE_SIZE;
+ }
if (extra->offset + d_sz >= d_D->size) {
d_sz = VK_WHOLE_SIZE;
}
@@ -2911,12 +2914,16 @@ static void ggml_vk_op_f32(vk_context * ctx, const ggml_tensor * src0, const ggm
break;
}
- x_sz *= ne02 * ne03;
- if (y_sz != VK_WHOLE_SIZE) {
- y_sz *= ne12 * ne13;
- }
if (op != GGML_OP_CPY) {
- d_sz *= ne02 * ne03;
+ if (x_sz != VK_WHOLE_SIZE) {
+ x_sz *= ne02 * ne03;
+ }
+ if (y_sz != VK_WHOLE_SIZE) {
+ y_sz *= ne12 * ne13;
+ }
+ if (d_sz != VK_WHOLE_SIZE) {
+ d_sz *= ne02 * ne03;
+ }
}
if (!use_src1 && op == GGML_OP_SOFT_MAX) {
diff --git a/ggml_vk_generate_shaders.py b/ggml_vk_generate_shaders.py
index 6b1b82bf..67981a75 100644
--- a/ggml_vk_generate_shaders.py
+++ b/ggml_vk_generate_shaders.py
@@ -1689,7 +1689,8 @@ void main() {
}
const float xi = float(data_a[i]);
- data_d[i] = D_TYPE(0.5f*xi*(1.0f + tanh(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi))));
+ const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
+ data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
}
"""