-rw-r--r--  ggml/src/ggml-vulkan.cpp                        40
-rw-r--r--  ggml/src/vulkan-shaders/fused_rms_norm.comp     54
-rw-r--r--  ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp   1
-rw-r--r--  src/llama.cpp                                     7
4 files changed, 96 insertions(+), 6 deletions(-)
diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
index c0bdfb7b..bd17fb26 100644
--- a/ggml/src/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan.cpp
@@ -431,6 +431,7 @@ struct vk_device_struct {
vk_pipeline pipeline_norm_f32;
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
+ vk_pipeline pipeline_fused_rms_norm_f32;
vk_pipeline pipeline_rms_norm_back_f32;

// [src/dst 0=fp32,1=fp16]
@@ -2653,6 +2654,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_fused_rms_norm_f32, "fused_rms_norm_f32", fused_rms_norm_f32_len, fused_rms_norm_f32_data, "main", 3, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -6381,6 +6383,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rms_norm_f32;
}
return nullptr;
+ case GGML_OP_FUSED_RMS_NORM:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_fused_rms_norm_f32;
+ }
+ return nullptr;
case GGML_OP_RMS_NORM_BACK:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_rms_norm_back_f32;
@@ -6521,6 +6528,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_REPEAT_BACK:
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
+ case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_IM2COL:
return true;
default:
@@ -6751,6 +6759,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
break;

+ case GGML_OP_FUSED_RMS_NORM:
+ elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+ break;
+
case GGML_OP_SUM:
// We use GGML_OP_SUM_ROWS with 1 row.
elements = { 1, 1, 1 };
@@ -7173,6 +7185,24 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
}, dryrun);
}

+static void ggml_vk_fused_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+ GGML_ASSERT(src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1);
+ GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_FUSED_RMS_NORM, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)src1->ne[0], 1u, 1u, 1u, (uint32_t)src1->nb[0] / src1_type_size, 0u, 0u, 0u,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ op_params[0], 0.0f, 0,
+ }, dryrun);
+}
+
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
@@ -8386,6 +8416,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
+ case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
@@ -8444,6 +8475,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
+ case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_UNARY:
case GGML_OP_DIAG_MASK_INF:
@@ -8551,6 +8583,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);

break;
+ case GGML_OP_FUSED_RMS_NORM:
+ ggml_vk_fused_rms_norm(ctx, compute_ctx, src0, src1, node, dryrun);
+
+ break;
case GGML_OP_RMS_NORM_BACK:
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);

@@ -8703,6 +8739,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
+ case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
@@ -9625,6 +9662,7 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_RMS_NORM:
+ case GGML_OP_FUSED_RMS_NORM:
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
@@ -10064,6 +10102,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
} else if (tensor->op == GGML_OP_RMS_NORM) {
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+ } else if (tensor->op == GGML_OP_FUSED_RMS_NORM) {
+ tensor_clone = ggml_fused_rms_norm(ggml_ctx, src_clone[0], src_clone[1], *(float *)tensor->op_params);
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
const float eps = ((float *) tensor->op_params)[0];
tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
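For intuition: GGML_OP_FUSED_RMS_NORM folds the usual RMS-norm/weight-multiply pair into one kernel (the check-results hook above simply rebuilds the same fused node on the reference backend for comparison). A minimal sketch of the unfused equivalent, using only stock ggml ops (the helper name is ours, not part of the tree):

    // Unfused reference: normalize each row of x, then scale column i by w[i].
    // x: F32 activations [ne00, ne01, ne02, ne03];
    // w: F32 weights [ne00, 1, 1, 1] -- the src1 shape asserted in
    //    ggml_vk_fused_rms_norm above.
    static struct ggml_tensor * rms_norm_unfused(
            struct ggml_context * ctx,
            struct ggml_tensor  * x,
            struct ggml_tensor  * w,
            float                 eps) {
        struct ggml_tensor * normed = ggml_rms_norm(ctx, x, eps); // x / sqrt(mean(x^2) + eps)
        return ggml_mul(ctx, normed, w);                          // w broadcasts across rows
    }

The fused kernel saves one full write and re-read of the activations between the two nodes, which is the point of giving the pair a dedicated pipeline.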
diff --git a/ggml/src/vulkan-shaders/fused_rms_norm.comp b/ggml/src/vulkan-shaders/fused_rms_norm.comp
new file mode 100644
index 00000000..fcea166c
--- /dev/null
+++ b/ggml/src/vulkan-shaders/fused_rms_norm.comp
@@ -0,0 +1,54 @@
+#version 450
+
+#include "generic_binary_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+shared FLOAT_TYPE sum[BLOCK_SIZE];
+
+void main() {
+ const uint ncols = p.ne00;
+ const uint nrows = gl_NumWorkGroups.x;
+ const uint nchannels = gl_NumWorkGroups.y;
+
+ const uint row = gl_WorkGroupID.x;
+ const uint channel = gl_WorkGroupID.y;
+ const uint samp = gl_WorkGroupID.z;
+ const uint tid = gl_LocalInvocationID.x;
+
+ const uint stride_row_a = p.nb01;
+ const uint stride_channel_a = p.nb02;
+ const uint stride_sample_a = p.nb03;
+
+ uint32_t a_offset = samp*stride_sample_a + channel*stride_channel_a + row*stride_row_a;
+ uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
+
+ FLOAT_TYPE sumf = FLOAT_TYPE(0.0f);
+
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
+ sumf += xi * xi;
+ }
+
+ sum[tid] = sumf;
+
+ // sum up partial sums and write back result
+ barrier();
+ [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ sum[tid] += sum[tid + s];
+ }
+ barrier();
+ }
+
+ const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
+ const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
+
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[col]));
+ }
+}
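Reading the shader back as math: one workgroup owns one row (gl_WorkGroupID.x/y/z select row, channel, and sample), the BLOCK_SIZE threads accumulate partial sums of squares, the shared-memory tree reduction collapses them, and the final loop writes

    y_i = w_i \cdot \frac{x_i}{\sqrt{\frac{1}{n} \sum_{j=0}^{n-1} x_j^2 + \varepsilon}}, \qquad n = \mathtt{ne00},\ \varepsilon = \mathtt{p.param1}

with w taken from src1: one value per column, shared by every row, channel, and sample, matching the host-side asserts on src1's shape. Note that d_offset is computed densely from ncols/nrows/nchannels, so the destination is assumed contiguous even though src0 is read through its own strides.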
diff --git a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
index 0f244dea..d622f1bd 100644
--- a/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -498,6 +498,7 @@ void process_shaders() {
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("fused_rms_norm_f32", "fused_rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
diff --git a/src/llama.cpp b/src/llama.cpp
index c2769e32..794dcca6 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -9597,12 +9597,7 @@ static struct ggml_tensor * llm_build_norm(
const llm_build_cb & cb,
int il, float scale_eps = 1) {
-#ifdef GGML_USE_VULKAN
- constexpr bool use_fused_rms_norm = false;
-#else
- constexpr bool use_fused_rms_norm = true;
-#endif
- if (use_fused_rms_norm && type == LLM_NORM_RMS && mw) {
+ if (type == LLM_NORM_RMS && mw) {
cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps);
if (mb) {
cb(cur, "fused_norm", il);
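With the Vulkan carve-out removed, every backend now takes the fused path whenever an RMS norm has a weight tensor mw. One detail worth noting in the call above: the epsilon handed to the fused op is the model's RMS-norm epsilon scaled by the caller-supplied factor,

    \varepsilon_{\mathrm{eff}} = \mathtt{scale\_eps} \cdot \mathtt{f\_norm\_rms\_eps}, \qquad \mathtt{scale\_eps} = 1 \text{ by default}

so a caller that needs a larger epsilon for one particular norm can pass scale_eps without touching the hparams.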