diff options
Diffstat (limited to 'ggml/src/vulkan-shaders')
-rw-r--r-- | ggml/src/vulkan-shaders/fused_rms_norm.comp | 54 | ||||
-rw-r--r-- | ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp | 1 |
2 files changed, 55 insertions, 0 deletions
diff --git a/ggml/src/vulkan-shaders/fused_rms_norm.comp b/ggml/src/vulkan-shaders/fused_rms_norm.comp new file mode 100644 index 00000000..fcea166c --- /dev/null +++ b/ggml/src/vulkan-shaders/fused_rms_norm.comp @@ -0,0 +1,54 @@ +#version 450 + +#include "generic_binary_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = gl_WorkGroupID.x; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + const uint tid = gl_LocalInvocationID.x; + + const uint stride_row_a = p.nb01; + const uint stride_channel_a = p.nb02; + const uint stride_sample_a = p.nb03; + + uint32_t a_offset = samp*stride_sample_a + channel*stride_channel_a + row*stride_row_a; + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sumf = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); + sumf += xi * xi; + } + + sum[tid] = sumf; + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[col])); + } +} diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp index 0f244dea..d622f1bd 100644 --- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp @@ -498,6 +498,7 @@ void process_shaders() { string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); |