summaryrefslogtreecommitdiff
path: root/ggml/src/vulkan-shaders
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/vulkan-shaders')
-rw-r--r--ggml/src/vulkan-shaders/fused_rms_norm.comp54
-rw-r--r--ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp1
2 files changed, 55 insertions, 0 deletions
diff --git a/ggml/src/vulkan-shaders/fused_rms_norm.comp b/ggml/src/vulkan-shaders/fused_rms_norm.comp
new file mode 100644
index 00000000..fcea166c
--- /dev/null
+++ b/ggml/src/vulkan-shaders/fused_rms_norm.comp
@@ -0,0 +1,54 @@
+#version 450
+
+#include "generic_binary_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+ const uint ncols = p.ne00;
+ const uint nrows = gl_NumWorkGroups.x;
+ const uint nchannels = gl_NumWorkGroups.y;
+
+ const uint row = gl_WorkGroupID.x;
+ const uint channel = gl_WorkGroupID.y;
+ const uint samp = gl_WorkGroupID.z;
+ const uint tid = gl_LocalInvocationID.x;
+
+ const uint stride_row_a = p.nb01;
+ const uint stride_channel_a = p.nb02;
+ const uint stride_sample_a = p.nb03;
+
+ uint32_t a_offset = samp*stride_sample_a + channel*stride_channel_a + row*stride_row_a;
+ uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
+
+ FLOAT_TYPE sumf = FLOAT_TYPE(0.0f);
+
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
+ sumf += xi * xi;
+ }
+
+ sum[tid] = sumf;
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ sum[tid] += sum[tid + s];
+ }
+ barrier();
+ }
+
+ const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
+ const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
+
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[col]));
+ }
+}
diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
index 0f244dea..d622f1bd 100644
--- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -498,6 +498,7 @@ void process_shaders() {
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});