diff options
-rw-r--r-- | ggml/src/ggml-vulkan.cpp | 40 | ||||
-rw-r--r-- | ggml/src/vulkan-shaders/fused_rms_norm.comp | 54 | ||||
-rw-r--r-- | ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp | 1 | ||||
-rw-r--r-- | src/llama.cpp | 7 |
4 files changed, 96 insertions, 6 deletions
diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index c0bdfb7b..bd17fb26 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -431,6 +431,7 @@ struct vk_device_struct { vk_pipeline pipeline_norm_f32; vk_pipeline pipeline_group_norm_f32; vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_fused_rms_norm_f32; vk_pipeline pipeline_rms_norm_back_f32; // [src/dst 0=fp32,1=fp16] @@ -2653,6 +2654,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_fused_rms_norm_f32, "fused_rms_norm_f32", fused_rms_norm_f32_len, fused_rms_norm_f32_data, "main", 3, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -6381,6 +6383,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rms_norm_f32; } return nullptr; + case GGML_OP_FUSED_RMS_NORM: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_fused_rms_norm_f32; + } + return nullptr; case GGML_OP_RMS_NORM_BACK: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_rms_norm_back_f32; @@ -6521,6 +6528,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) { case GGML_OP_REPEAT_BACK: case GGML_OP_ROPE: case GGML_OP_RMS_NORM: + case GGML_OP_FUSED_RMS_NORM: case GGML_OP_IM2COL: return true; default: @@ -6751,6 +6759,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; break; + case GGML_OP_FUSED_RMS_NORM: + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + break; + case GGML_OP_SUM: // We use GGML_OP_SUM_ROWS with 1 row. elements = { 1, 1, 1 }; @@ -7173,6 +7185,24 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, }, dryrun); } +static void ggml_vk_fused_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); + GGML_ASSERT(src1->ne[0] == src0->ne[0]); + + ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_FUSED_RMS_NORM, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], 1u, 1u, 1u, (uint32_t)src1->nb[0] / src1_type_size, 0u, 0u, 0u, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, 0, + }, dryrun); +} + static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { float * op_params = (float *)dst->op_params; ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); @@ -8386,6 +8416,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_FUSED_RMS_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -8444,6 +8475,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_FUSED_RMS_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_UNARY: case GGML_OP_DIAG_MASK_INF: @@ -8551,6 +8583,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); break; + case GGML_OP_FUSED_RMS_NORM: + ggml_vk_fused_rms_norm(ctx, compute_ctx, src0, src1, node, dryrun); + + break; case GGML_OP_RMS_NORM_BACK: ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); @@ -8703,6 +8739,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_NORM: case GGML_OP_GROUP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_FUSED_RMS_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -9625,6 +9662,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: case GGML_OP_RMS_NORM: + case GGML_OP_FUSED_RMS_NORM: return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: @@ -10064,6 +10102,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); } else if (tensor->op == GGML_OP_RMS_NORM) { tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_FUSED_RMS_NORM) { + tensor_clone = ggml_fused_rms_norm(ggml_ctx, src_clone[0], src_clone[1], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { const float eps = ((float *) tensor->op_params)[0]; tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); diff --git a/ggml/src/vulkan-shaders/fused_rms_norm.comp b/ggml/src/vulkan-shaders/fused_rms_norm.comp new file mode 100644 index 00000000..fcea166c --- /dev/null +++ b/ggml/src/vulkan-shaders/fused_rms_norm.comp @@ -0,0 +1,54 @@ +#version 450 + +#include "generic_binary_head.comp" +#include "types.comp" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = gl_WorkGroupID.x; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + const uint tid = gl_LocalInvocationID.x; + + const uint stride_row_a = p.nb01; + const uint stride_channel_a = p.nb02; + const uint stride_sample_a = p.nb03; + + uint32_t a_offset = samp*stride_sample_a + channel*stride_channel_a + row*stride_row_a; + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sumf = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]); + sumf += xi * xi; + } + + sum[tid] = sumf; + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[col])); + } +} diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp index 0f244dea..d622f1bd 100644 --- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp @@ -498,6 +498,7 @@ void process_shaders() { string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); diff --git a/src/llama.cpp b/src/llama.cpp index c2769e32..794dcca6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9597,12 +9597,7 @@ static struct ggml_tensor * llm_build_norm( const llm_build_cb & cb, int il, float scale_eps = 1) { -#ifdef GGML_USE_VULKAN - constexpr bool use_fused_rms_norm = false; -#else - constexpr bool use_fused_rms_norm = true; -#endif - if (use_fused_rms_norm && type == LLM_NORM_RMS && mw) { + if (type == LLM_NORM_RMS && mw) { cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps); if (mb) { cb(cur, "fused_norm", il); |